Add SPV_KHR_bfloat16 support (#6057)

* Add SPV_KHR_bfloat16 support

* Update DEPS to include SPIRV-Headers with bfloat16 support

* Fix unit test errors and format

* Add validation to invalid uses of bfloat16

* Add tests

* Roll back to previous commit

* Fix build error

* Add FPEncoding for opt::analysis::Float

* Address the comments

* Fix build error

* format

---------

Co-authored-by: Stu Smith <19190608+stu-s@users.noreply.github.com>
Co-authored-by: David Neto <dneto@google.com>
diff --git a/Android.mk b/Android.mk
index b2b96d9..1c150c4 100644
--- a/Android.mk
+++ b/Android.mk
@@ -76,7 +76,8 @@
 		source/val/validate_scopes.cpp \
 		source/val/validate_small_type_uses.cpp \
 		source/val/validate_tensor_layout.cpp \
-		source/val/validate_type.cpp
+		source/val/validate_type.cpp\
+		source/val/validate_invalid_type.cpp
 
 SPVTOOLS_OPT_SRC_FILES := \
 		source/opt/aggressive_dead_code_elim_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 5fd00ca..82cf0e5 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -559,6 +559,7 @@
     "source/val/validate_small_type_uses.cpp",
     "source/val/validate_tensor_layout.cpp",
     "source/val/validate_type.cpp",
+    "source/val/validate_invalid_type.cpp",
     "source/val/validation_state.cpp",
     "source/val/validation_state.h",
   ]
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index b5608f0..cc5bc53 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -336,6 +336,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_invalid_type.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
   ${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/construct.cpp
diff --git a/source/name_mapper.cpp b/source/name_mapper.cpp
index 24c24e2..ae3ef49 100644
--- a/source/name_mapper.cpp
+++ b/source/name_mapper.cpp
@@ -211,7 +211,12 @@
     } break;
     case spv::Op::OpTypeFloat: {
       const auto bit_width = inst.words[2];
-      // TODO: Handle optional fpencoding enum once actually used.
+      if (inst.num_words > 3) {
+        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::BFloat16KHR) {
+          SaveName(result_id, "bfloat16");
+          break;
+        }
+      }
       switch (bit_width) {
         case 16:
           SaveName(result_id, "half");
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index e0c4b80..e37b433 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -792,9 +792,13 @@
       type = new Integer(inst.GetSingleWordInOperand(0),
                          inst.GetSingleWordInOperand(1));
       break;
-    case spv::Op::OpTypeFloat:
-      type = new Float(inst.GetSingleWordInOperand(0));
-      break;
+    case spv::Op::OpTypeFloat: {
+      const spv::FPEncoding encoding =
+          inst.NumInOperands() > 1
+              ? static_cast<spv::FPEncoding>(inst.GetSingleWordInOperand(1))
+              : spv::FPEncoding::Max;
+      type = new Float(inst.GetSingleWordInOperand(0), encoding);
+    } break;
     case spv::Op::OpTypeVector:
       type = new Vector(GetType(inst.GetSingleWordInOperand(0)),
                         inst.GetSingleWordInOperand(1));
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 2023719..a35f871 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -309,17 +309,26 @@
 
 bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
   const Float* ft = that->AsFloat();
-  return ft && width_ == ft->width_ && HasSameDecorations(that);
+  return ft && width_ == ft->width_ && encoding_ == ft->encoding_ &&
+         HasSameDecorations(that);
 }
 
 std::string Float::str() const {
   std::ostringstream oss;
-  oss << "float" << width_;
+  switch (encoding_) {
+    case spv::FPEncoding::BFloat16KHR:
+      assert(width_ == 16);
+      oss << "bfloat16";
+      break;
+    default:
+      oss << "float" << width_;
+      break;
+  }
   return oss.str();
 }
 
 size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
-  return hash_combine(hash, width_);
+  return hash_combine(hash, width_, encoding_);
 }
 
 Vector::Vector(const Type* type, uint32_t count)
diff --git a/source/opt/types.h b/source/opt/types.h
index 1418331..99b3cd8 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -182,9 +182,9 @@
   // non-composite type.
   uint64_t NumberOfComponents() const;
 
-// A bunch of methods for casting this type to a given type. Returns this if the
-// cast can be done, nullptr otherwise.
-// clang-format off
+  // A bunch of methods for casting this type to a given type. Returns this if
+  // the cast can be done, nullptr otherwise.
+  // clang-format off
 #define DeclareCastMethod(target)                  \
   virtual target* As##target() { return nullptr; } \
   virtual const target* As##target() const { return nullptr; }
@@ -267,7 +267,8 @@
 
 class Float : public Type {
  public:
-  Float(uint32_t w) : Type(kFloat), width_(w) {}
+  Float(uint32_t w, spv::FPEncoding encoding = spv::FPEncoding::Max)
+      : Type(kFloat), width_(w), encoding_(encoding) {}
   Float(const Float&) = default;
 
   std::string str() const override;
@@ -275,13 +276,15 @@
   Float* AsFloat() override { return this; }
   const Float* AsFloat() const override { return this; }
   uint32_t width() const { return width_; }
+  spv::FPEncoding encoding() const { return encoding_; }
 
   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
 
  private:
   bool IsSameImpl(const Type* that, IsSameCache*) const override;
 
-  uint32_t width_;  // bit width
+  uint32_t width_;            // bit width
+  spv::FPEncoding encoding_;  // FPEncoding
 };
 
 class Vector : public Type {
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 2d10347..4c46d2b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -367,6 +367,7 @@
     if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
     if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
     if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
+    if (auto error = InvalidTypePass(*vstate, &instruction)) return error;
   }
 
   // Validate the preconditions involving adjacent instructions. e.g.
diff --git a/source/val/validate.h b/source/val/validate.h
index 5514ff7..5d13a7b 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -223,6 +223,9 @@
 /// Validates correctness of mesh shading instructions.
 spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
 
+/// Validates correctness of certain special type instructions.
+spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst);
+
 /// Calculates the reachability of basic blocks.
 void ReachabilityPass(ValidationState_t& _);
 
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index d252ec9..38281be 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -224,6 +224,14 @@
                << "Expected float scalar type as Result Type: "
                << spvOpcodeString(opcode);
 
+      if (_.IsBfloat16ScalarType(result_type)) {
+        if (!_.HasCapability(spv::Capability::BFloat16DotProductKHR)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "OpDot Result Type <id> " << _.getIdName(result_type)
+                 << "requires BFloat16DotProductKHR be declared.";
+        }
+      }
+
       uint32_t first_vector_num_components = 0;
 
       for (size_t operand_index = 2; operand_index < inst->operands().size();
diff --git a/source/val/validate_capability.cpp b/source/val/validate_capability.cpp
index 81d2ad5..303142b 100644
--- a/source/val/validate_capability.cpp
+++ b/source/val/validate_capability.cpp
@@ -100,6 +100,7 @@
     case spv::Capability::GeometryStreams:
     case spv::Capability::Float16:
     case spv::Capability::Int8:
+    case spv::Capability::BFloat16TypeKHR:
       return true;
     default:
       break;
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index c459ec3..8bf87ad 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -264,11 +264,18 @@
                  << spvOpcodeString(opcode);
       }
 
-      if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have different bit width from Result "
-                  "Type: "
-               << spvOpcodeString(opcode);
+      // Scalar type
+      const uint32_t resScalarType = _.GetComponentType(result_type);
+      const uint32_t inputScalartype = _.GetComponentType(input_type);
+      if (_.GetBitWidth(resScalarType) == _.GetBitWidth(inputScalartype))
+        if ((_.IsBfloat16ScalarType(resScalarType) &&
+             _.IsBfloat16ScalarType(inputScalartype)) ||
+            (!_.IsBfloat16ScalarType(inputScalartype) &&
+             !_.IsBfloat16ScalarType(resScalarType)))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have different bit width from Result "
+                    "Type: "
+                 << spvOpcodeString(opcode);
       break;
     }
 
diff --git a/source/val/validate_interfaces.cpp b/source/val/validate_interfaces.cpp
index 68809b5..e01a08e 100644
--- a/source/val/validate_interfaces.cpp
+++ b/source/val/validate_interfaces.cpp
@@ -642,6 +642,28 @@
         has_callable_data = true;
         break;
       }
+      case spv::StorageClass::Input:
+      case spv::StorageClass::Output: {
+        auto result_type = _.FindDef(interface_var->type_id());
+        if (_.ContainsType(result_type->GetOperandAs<uint32_t>(2),
+                           [](const Instruction* inst) {
+                             if (inst &&
+                                 inst->opcode() == spv::Op::OpTypeFloat) {
+                               if (inst->words().size() > 3) {
+                                 if (inst->GetOperandAs<spv::FPEncoding>(2) ==
+                                     spv::FPEncoding::BFloat16KHR) {
+                                   return true;
+                                 }
+                               }
+                             }
+                             return false;
+                           })) {
+          return _.diag(SPV_ERROR_INVALID_ID, interface_var)
+                 << _.VkErrorID(10370) << "Bfloat16 OpVariable <id> "
+                 << _.getIdName(interface_var->id()) << " must not be declared "
+                 << "with a Storage Class of Input or Output.";
+        }
+      }
       default:
         break;
     }
diff --git a/source/val/validate_invalid_type.cpp b/source/val/validate_invalid_type.cpp
new file mode 100644
index 0000000..a9dcd29
--- /dev/null
+++ b/source/val/validate_invalid_type.cpp
@@ -0,0 +1,139 @@
+// Copyright (c) 2025 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Performs validation of invalid type instructions.
+
+#include <vector>
+
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of certain special type instructions.
+spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
+  const spv::Op opcode = inst->opcode();
+
+  switch (opcode) {
+    // OpExtInst
+    case spv::Op::OpExtInst:
+    // Arithmetic Instructions
+    case spv::Op::OpFAdd:
+    case spv::Op::OpFSub:
+    case spv::Op::OpFMul:
+    case spv::Op::OpFDiv:
+    case spv::Op::OpFRem:
+    case spv::Op::OpFMod:
+    case spv::Op::OpFNegate:
+    // Derivative Instructions
+    case spv::Op::OpDPdx:
+    case spv::Op::OpDPdy:
+    case spv::Op::OpFwidth:
+    case spv::Op::OpDPdxFine:
+    case spv::Op::OpDPdyFine:
+    case spv::Op::OpFwidthFine:
+    case spv::Op::OpDPdxCoarse:
+    case spv::Op::OpDPdyCoarse:
+    case spv::Op::OpFwidthCoarse:
+    // Atomic Instructions
+    case spv::Op::OpAtomicFAddEXT:
+    case spv::Op::OpAtomicFMinEXT:
+    case spv::Op::OpAtomicFMaxEXT:
+    case spv::Op::OpAtomicLoad:
+    case spv::Op::OpAtomicExchange:
+    // Group and Subgroup Instructions
+    case spv::Op::OpGroupNonUniformRotateKHR:
+    case spv::Op::OpGroupNonUniformBroadcast:
+    case spv::Op::OpGroupNonUniformShuffle:
+    case spv::Op::OpGroupNonUniformShuffleXor:
+    case spv::Op::OpGroupNonUniformShuffleUp:
+    case spv::Op::OpGroupNonUniformShuffleDown:
+    case spv::Op::OpGroupNonUniformQuadBroadcast:
+    case spv::Op::OpGroupNonUniformQuadSwap:
+    case spv::Op::OpGroupNonUniformBroadcastFirst:
+    case spv::Op::OpGroupNonUniformFAdd:
+    case spv::Op::OpGroupNonUniformFMul:
+    case spv::Op::OpGroupNonUniformFMin: {
+      const uint32_t result_type = inst->type_id();
+      if (_.IsBfloat16ScalarType(result_type) ||
+          _.IsBfloat16VectorType(result_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+      }
+      break;
+    }
+
+    case spv::Op::OpAtomicStore: {
+      uint32_t data_type =
+          _.FindDef(inst->GetOperandAs<uint32_t>(3))->type_id();
+      if (_.IsBfloat16VectorType(data_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+      }
+      break;
+    }
+    // Relational and Logical Instructions
+    case spv::Op::OpIsNan:
+    case spv::Op::OpIsInf:
+    case spv::Op::OpIsFinite:
+    case spv::Op::OpIsNormal:
+    case spv::Op::OpSignBitSet: {
+      const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
+      if (_.IsBfloat16ScalarType(operand_type) ||
+          _.IsBfloat16VectorType(operand_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+      }
+      break;
+    }
+
+    case spv::Op::OpGroupNonUniformAllEqual: {
+      const auto value_type = _.GetOperandTypeId(inst, 3);
+      if (_.IsBfloat16ScalarType(value_type) ||
+          _.IsBfloat16VectorType(value_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+      }
+      break;
+    }
+
+    case spv::Op::OpMatrixTimesMatrix: {
+      const uint32_t result_type = inst->type_id();
+      uint32_t res_num_rows = 0;
+      uint32_t res_num_cols = 0;
+      uint32_t res_col_type = 0;
+      uint32_t res_component_type = 0;
+      if (_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
+                              &res_col_type, &res_component_type)) {
+        if (_.IsBfloat16ScalarType(res_component_type)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << " doesn't support BFloat16 type.";
+        }
+      }
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index b5fa04f..4377271 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -115,6 +115,15 @@
   if (num_bits == 32) {
     return SPV_SUCCESS;
   }
+  auto operands = inst->words();
+  if (operands.size() > 3) {
+    if (operands[3] != 0) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Current FPEncoding only supports BFloat16KHR.";
+    }
+    return SPV_SUCCESS;
+  }
+
   if (num_bits == 16) {
     if (_.features().declare_float16_type) {
       return SPV_SUCCESS;
@@ -643,6 +652,15 @@
            << " is not a scalar numerical type.";
   }
 
+  if (_.IsBfloat16ScalarType(component_type_id)) {
+    if (!_.HasCapability(spv::Capability::BFloat16CooperativeMatrixKHR)) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "OpTypeCooperativeMatrix Component Type <id> "
+             << _.getIdName(component_type_id)
+             << "require BFloat16CooperativeMatrixKHR be declared.";
+    }
+  }
+
   const auto scope_index = 2;
   const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
   const auto scope = _.FindDef(scope_id);
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index db7e020..96d80f0 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -946,6 +946,32 @@
   return inst && inst->opcode() == spv::Op::OpTypeVoid;
 }
 
+bool ValidationState_t::IsBfloat16ScalarType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
+    if (inst->words().size() > 3) {
+      if (inst->GetOperandAs<spv::FPEncoding>(2) ==
+          spv::FPEncoding::BFloat16KHR) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  if (!inst) {
+    return false;
+  }
+
+  if (inst->opcode() == spv::Op::OpTypeVector) {
+    return IsBfloat16ScalarType(GetComponentType(id));
+  }
+
+  return false;
+}
+
 bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   return inst && inst->opcode() == spv::Op::OpTypeFloat;
@@ -1768,6 +1794,10 @@
 
   const auto f = [type, width](const Instruction* inst) {
     if (inst->opcode() == type) {
+      // Bfloat16 is a special type.
+      if (type == spv::Op::OpTypeFloat && inst->words().size() > 3)
+        return false;
+
       return inst->GetOperandAs<uint32_t>(1u) == width;
     }
     return false;
@@ -2536,6 +2566,8 @@
     case 10213:
       // This use to be a standalone, but maintenance8 will set allow_offset_texture_operand now
       return VUID_WRAP(VUID-RuntimeSpirv-Offset-10213);
+    case 10370:
+      return VUID_WRAP(VUID-StandaloneSpirv-OpTypeFloat-10370);
     case 10583:
       return VUID_WRAP(VUID-StandaloneSpirv-Component-10583);
     case 10684:
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index e97d3d3..8cc87a4 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -633,6 +633,8 @@
   // Returns true iff |id| is a type corresponding to the name of the function.
   // Only works for types not for objects.
   bool IsVoidType(uint32_t id) const;
+  bool IsBfloat16ScalarType(uint32_t id) const;
+  bool IsBfloat16VectorType(uint32_t id) const;
   bool IsFloatScalarType(uint32_t id) const;
   bool IsFloatArrayType(uint32_t id) const;
   bool IsFloatVectorType(uint32_t id) const;
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index 4ceeb14..01c8e90 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -364,6 +364,20 @@
   }
 }
 
+TEST(Types, FloatFPEncoding) {
+  std::vector<spv::FPEncoding> encodings = {
+      spv::FPEncoding::BFloat16KHR,
+      spv::FPEncoding::Max,
+  };
+  std::vector<std::unique_ptr<Float>> types;
+  for (spv::FPEncoding encoding : encodings) {
+    types.emplace_back(new Float(16, encoding));
+  }
+  for (size_t i = 0; i < encodings.size(); i++) {
+    EXPECT_EQ(encodings[i], types[i]->encoding());
+  }
+}
+
 TEST(Types, VectorElementCount) {
   auto s32 = MakeUnique<Integer>(32, true);
   for (uint32_t c : {2, 3, 4}) {
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 9d6f6ea..c356753 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -38,6 +38,7 @@
        val_derivatives_test.cpp
        val_entry_point_test.cpp
        val_explicit_reserved_test.cpp
+       val_invalid_type_test.cpp
        val_extensions_test.cpp
        val_extension_spv_khr_expect_assume_test.cpp
        val_extension_spv_khr_linkonce_odr_test.cpp
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 2a15b28..85653c5 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1166,6 +1166,50 @@
                 "vector size of the right operand: OuterProduct"));
 }
 
+std::string GenerateBFloatCode(const std::string& main_body) {
+  const std::string prefix =
+      R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpCapability BFloat16DotProductKHR
+OpCapability BFloat16CooperativeMatrixKHR
+OpExtension "SPV_KHR_bfloat16"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+%v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+  const std::string suffix =
+      R"(
+OpReturn
+OpFunctionEnd)";
+
+  return prefix + main_body + suffix;
+}
+
+TEST_F(ValidateArithmetics, DotBfloat16) {
+  const std::string body = R"(
+%v1 = OpVariable %_ptr_Function_v2bfloat16 Function
+%v2 = OpVariable %_ptr_Function_v2bfloat16 Function
+%12 = OpLoad %v2bfloat16 %v1
+%14 = OpLoad %v2bfloat16 %v2
+%15 = OpDot %bfloat16 %12 %14
+)";
+
+  CompileSuccessfully(GenerateBFloatCode(body).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 std::string GenerateCoopMatCode(const std::string& extra_types,
                                 const std::string& main_body) {
   const std::string prefix =
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 52d40d3..f0f5c46 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -39,6 +39,7 @@
   const std::string capabilities =
       R"(
 OpCapability Shader
+OpCapability Float16
 OpCapability Int64
 OpCapability Float64)";
 
@@ -54,6 +55,7 @@
 %func = OpTypeFunction %void
 %bool = OpTypeBool
 %f32 = OpTypeFloat 32
+%f16 = OpTypeFloat 16
 %u32 = OpTypeInt 32 0
 %s32 = OpTypeInt 32 1
 %f64 = OpTypeFloat 64
@@ -84,6 +86,8 @@
 %f32_3 = OpConstant %f32 3
 %f32_4 = OpConstant %f32 4
 
+%f16_1 = OpConstant %f16 1
+
 %s32_0 = OpConstant %s32 0
 %s32_1 = OpConstant %s32 1
 %s32_2 = OpConstant %s32 2
@@ -593,6 +597,24 @@
                         "Result Type: FConvert"));
 }
 
+TEST_F(ValidateConversion, FConvertFloat16ToBFloat16) {
+  const std::string extensions = R"(
+OpCapability BFloat16TypeKHR
+OpExtension "SPV_KHR_bfloat16"
+)";
+
+  const std::string types = R"(
+%bf16 = OpTypeFloat 16 BFloat16KHR
+)";
+
+  const std::string body = R"(
+%val = OpFConvert %bf16 %f16_1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body, extensions, "", types).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 TEST_F(ValidateConversion, QuantizeToF16Success) {
   const std::string body = R"(
 %val1 = OpQuantizeToF16 %f32 %f32_1
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index 349e5e9..a6b90fd 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -71,6 +71,15 @@
      OpCapability Int64
      OpMemoryModel Logical GLSL450
 )";
+std::string header_with_bfloat16 = R"(
+     OpCapability Shader
+     OpCapability Linkage
+     OpCapability BFloat16TypeKHR
+     OpCapability BFloat16DotProductKHR
+     OpCapability BFloat16CooperativeMatrixKHR
+     OpExtension "SPV_KHR_bfloat16"
+     OpMemoryModel Logical GLSL450
+)";
 std::string header_with_float16 = R"(
      OpCapability Shader
      OpCapability Linkage
@@ -334,6 +343,25 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateData, bfloat16_good) {
+  std::string str = header_with_bfloat16 + "%2 = OpTypeFloat 16 BFloat16KHR";
+  CompileSuccessfully(str.c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateData, cooperative_matrix_bfloat16_good) {
+  std::string str = header_with_bfloat16 + R"(
+%u32 = OpTypeInt 32 0
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+%bf16 = OpTypeFloat 16 BFloat16KHR
+%bf16matA = OpTypeCooperativeMatrixKHR %bf16 %subgroup %u32_16 %u32_16 %useA
+)";
+  CompileSuccessfully(str.c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 TEST_F(ValidateData, float16_buffer_good) {
   std::string str = header_with_float16_buffer + "%2 = OpTypeFloat 16";
   CompileSuccessfully(str.c_str());
@@ -347,8 +375,50 @@
   EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float16_cap_error));
 }
 
+TEST_F(ValidateData, bfloat16_bad) {
+  std::string str = header + "%2 = OpTypeFloat 16 BFloat16KHR";
+  CompileSuccessfully(str.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("requires one of these capabilities: BFloat16TypeKHR"));
+}
+
+TEST_F(ValidateData, dot_bfloat16_bad) {
+  std::string str = R"(
+               OpCapability Shader
+               OpCapability BFloat16TypeKHR
+               OpExtension "SPV_KHR_bfloat16"
+          %1 = OpExtInstImport "GLSL.std.450"
+                OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource GLSL 450
+               OpName %main "main"
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %bfloat16 = OpTypeFloat 16 BFloat16KHR
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+    %v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %v1 = OpVariable %_ptr_Function_v2bfloat16 Function
+         %v2 = OpVariable %_ptr_Function_v2bfloat16 Function
+         %12 = OpLoad %v2bfloat16 %v1
+         %14 = OpLoad %v2bfloat16 %v2
+         %15 = OpDot %bfloat16 %12 %14
+               OpReturn
+               OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("requires BFloat16DotProductKHR be declared."));
+}
+
 TEST_F(ValidateData, float64_good) {
   std::string str = header_with_float64 + "%2 = OpTypeFloat 64";
+
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
@@ -627,7 +697,8 @@
         OpExtension "SPV_KHR_variable_pointers"
         OpExtension "SPV_KHR_16bit_storage"
         OpMemoryModel Logical GLSL450
-        OpDecorate %_ FPRoundingMode )" + mode + R"(
+        OpDecorate %_ FPRoundingMode )" +
+                        mode + R"(
         %half = OpTypeFloat 16
         %float = OpTypeFloat 32
         %float_1_25 = OpConstant %float 1.25
diff --git a/test/val/val_interfaces_test.cpp b/test/val/val_interfaces_test.cpp
index 2aa8033..6dff110 100644
--- a/test/val/val_interfaces_test.cpp
+++ b/test/val/val_interfaces_test.cpp
@@ -2022,6 +2022,36 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_3));
 }
 
+TEST_F(ValidateInterfacesTest, ValidInstructionType) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpExtension "SPV_KHR_bfloat16"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %in %out
+OpExecutionMode %main OriginUpperLeft
+OpDecorate %in Location 0
+OpDecorate %out Location 0
+%void = OpTypeVoid
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%in_ptr = OpTypePointer Input %bfloat16
+%out_ptr = OpTypePointer Output %bfloat16
+%in = OpVariable %in_ptr Input
+%out = OpVariable %out_ptr Output
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text, SPV_ENV_VULKAN_1_3);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_3));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Bfloat16 OpVariable <id> '2[%2]' must not be declared "
+                        "with a Storage Class of Input or Output.\n"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools
diff --git a/test/val/val_invalid_type_test.cpp b/test/val/val_invalid_type_test.cpp
new file mode 100644
index 0000000..dc911e1
--- /dev/null
+++ b/test/val/val_invalid_type_test.cpp
@@ -0,0 +1,117 @@
+// Copyright (c) 2025 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for invalid types.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "source/spirv_target_env.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+using ::testing::Values;
+
+using ValidateInvalidType = spvtest::ValidateBase<bool>;
+
+std::string GenerateBFloatCode(const std::string& main_body) {
+  const std::string prefix =
+      R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpCapability AtomicFloat16AddEXT
+OpCapability GroupNonUniformShuffle
+OpExtension "SPV_EXT_shader_atomic_float16_add"
+OpExtension "SPV_KHR_bfloat16"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpName %main "main"
+%void = OpTypeVoid
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%func = OpTypeFunction %void
+%u32 = OpTypeInt 32 0
+%u1 = OpConstant %u32 1
+%u0 = OpConstant %u32 0
+%u3 = OpConstant %u32 3
+%bf16_1 = OpConstant %bfloat16 1
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+%v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+%bf16_ptr = OpTypePointer Workgroup %bfloat16
+%bf16_var = OpVariable %bf16_ptr Workgroup
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+  const std::string suffix =
+      R"(
+OpReturn
+OpFunctionEnd)";
+
+  return prefix + main_body + suffix;
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidArithmeticInstruction) {
+  const std::string body = R"(
+%v1 = OpVariable %_ptr_Function_bfloat16 Function
+%v2 = OpVariable %_ptr_Function_bfloat16 Function
+%12 = OpLoad %bfloat16 %v1
+%14 = OpLoad %bfloat16 %v2
+%15 = OpFMul %bfloat16 %12 %14
+)";
+
+  CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+            ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("FMul doesn't support BFloat16 type."));
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidAtomicInstruction) {
+  const std::string body = R"(
+%val1 = OpAtomicFAddEXT %bfloat16 %bf16_var %u1 %u0 %bf16_1
+)";
+
+  CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+            ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("AtomicFAddEXT doesn't support BFloat16 type."));
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidGroupNonUniformShuffle) {
+  const std::string body = R"(
+%val1 = OpGroupNonUniformShuffle %bfloat16 %u3 %bf16_1 %u0
+)";
+
+  CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+            ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("GroupNonUniformShuffle doesn't support BFloat16 type."));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools