Add support for Intel(R) AMX ISA
diff --git a/xbyak/xbyak.h b/xbyak/xbyak.h
index 63efccd..c5399c4 100644
--- a/xbyak/xbyak.h
+++ b/xbyak/xbyak.h
@@ -397,7 +397,7 @@
class Operand {
static const uint8 EXT8BIT = 0x20;
unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil
- unsigned int kind_:9;
- unsigned int bit_:10;
+ unsigned int kind_:10; // one bit per Kind flag; widened so that TMM (1 << 9) fits
+ unsigned int bit_:14; // operand width in bits; 14 bits are needed to hold the 8192-bit tile width (a 10-bit field would store 8192 as 0)
protected:
unsigned int zero_:1;
@@ -415,7 +415,8 @@
YMM = 1 << 5,
ZMM = 1 << 6,
OPMASK = 1 << 7,
- BNDREG = 1 << 8
+ BNDREG = 1 << 8,
+ TMM = 1 << 9 // kind flag for AMX tile registers (8192-bit operands)
};
enum Code {
#ifdef XBYAK64
@@ -445,6 +446,7 @@
bool isXMM() const { return is(XMM); }
bool isYMM() const { return is(YMM); }
bool isZMM() const { return is(ZMM); }
+ bool isTMM() const { return is(TMM); } // kind test for tile registers, same shape as isXMM/isYMM/isZMM
bool isXMEM() const { return is(XMM | MEM); }
bool isYMEM() const { return is(YMM | MEM); }
bool isZMEM() const { return is(ZMM | MEM); }
@@ -463,9 +465,9 @@
int getRounding() const { return rounding_; }
void setKind(Kind kind)
{
- if ((kind & (XMM|YMM|ZMM)) == 0) return;
+ if ((kind & (XMM|YMM|ZMM|TMM)) == 0) return; // silently ignore kinds that carry none of the XMM/YMM/ZMM/TMM flags
kind_ = kind;
- bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512;
+ bit_ = kind == XMM ? 128 : kind == YMM ? 256 : kind == ZMM ? 512 : 8192; // width follows kind; anything that is not exactly XMM/YMM/ZMM gets 8192 (intended for TMM)
}
// err if MMX/FPU/OPMASK/BNDREG
void setBit(int bit);
@@ -513,6 +515,11 @@
} else if (isOPMASK()) {
static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" };
return tbl[idx];
+ } else if (isTMM()) { // tile registers: eight names, looked up by idx
+ static const char *tbl[8] = {
+ "tmm0", "tmm1", "tmm2", "tmm3", "tmm4", "tmm5", "tmm6", "tmm7"
+ };
+ return tbl[idx];
} else if (isZMM()) {
static const char *tbl[32] = {
"zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15",
@@ -552,13 +559,13 @@
inline void Operand::setBit(int bit)
{
- if (bit != 8 && bit != 16 && bit != 32 && bit != 64 && bit != 128 && bit != 256 && bit != 512) goto ERR;
+ if (bit != 8 && bit != 16 && bit != 32 && bit != 64 && bit != 128 && bit != 256 && bit != 512 && bit != 8192) goto ERR; // 8192 (tile width) is now an accepted size
if (isBit(bit)) return;
if (is(MEM | OPMASK)) {
bit_ = bit;
return;
}
- if (is(REG | XMM | YMM | ZMM)) {
+ if (is(REG | XMM | YMM | ZMM | TMM)) { // TMM operands now take the same resize path as general/vector registers
int idx = getIdx();
// err if converting ah, bh, ch, dh
if (isREG(8) && (4 <= idx && idx < 8) && !isExt8bit()) goto ERR;
@@ -580,6 +587,7 @@
case 128: kind = XMM; break;
case 256: kind = YMM; break;
case 512: kind = ZMM; break;
+ case 8192: kind = TMM; break; // a width of 8192 bits selects the tile register kind
}
idx_ = idx;
kind_ = kind;
@@ -674,6 +682,10 @@
Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; }
};
+struct Tmm : public Reg { // tile register operand; defaults to kind TMM and a width of 8192 bits
+ explicit Tmm(int idx = 0, Kind kind = Operand::TMM, int bit = 8192) : Reg(idx, kind, bit) { }
+};
+
struct Opmask : public Reg {
explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {}
};
@@ -782,7 +794,7 @@
: scale_(scale)
, disp_(0)
{
- if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM|Reg::TMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); // TMM is now accepted here in addition to i32e registers and XMM/YMM/ZMM
if (scale == 0) return;
if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE);
if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index
@@ -1583,6 +1595,7 @@
T_M_K = 1 << 28, // mem{k}
T_VSIB = 1 << 29,
T_MEM_EVEX = 1 << 30, // use evex if mem
+ T_TMM = 1 << 31, // type flag for tile instructions; opAMX rejects calls whose type lacks it
T_XXX
};
void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false)
@@ -2250,6 +2263,19 @@
}
throw Error(ERR_BAD_COMBINATION);
}
+ void opAMX(const Tmm& t1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) // validate a tile instruction and forward it to opVex; throws ERR_BAD_COMBINATION on a bad type/register mix
+ {
+ const Reg *t2 = static_cast<const Reg*>(&op1); // three-operand form: op1 is taken as the second register
+ const Operand *op = &op2;
+ if (op2.isNone()) { // <i>(t1, op1) -> <i>(t1, t1, op1)
+ t2 = &t1;
+ op = &op1;
+ }
+ // <i>(t1, t2, op)
+ if (!((type & T_TMM) && (t1.isTMM() && t2->isTMM()))) throw Error(ERR_BAD_COMBINATION); // needs the T_TMM flag and both registers to be tiles
+
+ opVex(t1, t2, *op, type, code0, imm8);
+ }
public:
unsigned int getVersion() const { return VERSION; }
using CodeArray::db;
@@ -2285,6 +2311,7 @@
const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15;
const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23;
const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
+ const Tmm tmm0, tmm1, tmm2, tmm3, tmm4, tmm5, tmm6, tmm7; // the eight tile register operands
const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience
const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23;
const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31;
@@ -2294,6 +2321,7 @@
const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15;
const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23;
const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31;
+ const Tmm &tm0, &tm1, &tm2, &tm3, &tm4, &tm5, &tm6, &tm7; // short reference aliases for tmm0..tmm7
const RegRip rip;
#endif
#ifndef XBYAK_DISABLE_SEGMENT
@@ -2566,6 +2594,7 @@
, zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15)
, zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23)
, zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31)
+ , tmm0(0), tmm1(1), tmm2(2), tmm3(3), tmm4(4), tmm5(5), tmm6(6), tmm7(7) // tmm<i> is constructed with register index i
// for my convenience
, xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15)
, xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23)
@@ -2576,6 +2605,7 @@
, zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15)
, zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23)
, zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31)
+ , tm0(tmm0), tm1(tmm1), tm2(tmm2), tm3(tmm3), tm4(tmm4), tm5(tmm5), tm6(tmm6), tm7(tmm7) // bind each tm<i> alias to the matching tmm<i> member
, rip()
#endif
#ifndef XBYAK_DISABLE_SEGMENT
@@ -2702,6 +2732,7 @@
static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15);
static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23);
static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31);
+static const Tmm tmm0(0), tmm1(1), tmm2(2), tmm3(3), tmm4(4), tmm5(5), tmm6(6), tmm7(7); // static tile register constants, one per index 0..7
static const RegRip rip;
#endif
#ifndef XBYAK_DISABLE_SEGMENT
diff --git a/xbyak/xbyak_util.h b/xbyak/xbyak_util.h
index 4f79d8f..4246367 100644
--- a/xbyak/xbyak_util.h
+++ b/xbyak/xbyak_util.h
@@ -353,6 +353,9 @@
static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56;
static const Type tAVX512_BF16 = uint64(1) << 57;
static const Type tAVX512_VP2INTERSECT = uint64(1) << 58;
+ static const Type tAMX_TILE = uint64(1) << 59; // feature flag: AMX tile support
+ static const Type tAMX_INT8 = uint64(1) << 60; // feature flag: AMX INT8 support
+ static const Type tAMX_BF16 = uint64(1) << 61; // feature flag: AMX BF16 support
Cpu()
: type_(NONE)
@@ -456,6 +459,9 @@
if (EBX & (1U << 14)) type_ |= tMPX;
if (EBX & (1U << 29)) type_ |= tSHA;
if (ECX & (1U << 0)) type_ |= tPREFETCHWT1;
+ if (EDX & (1U << 24)) type_ |= tAMX_TILE; // EDX bit 24 -> AMX-TILE (presumably CPUID leaf 7, subleaf 0 - confirm against the enclosing cpuid call)
+ if (EDX & (1U << 25)) type_ |= tAMX_INT8; // EDX bit 25 -> AMX-INT8
+ if (EDX & (1U << 22)) type_ |= tAMX_BF16; // EDX bit 22 -> AMX-BF16
}
setFamily();
setNumCores();