Added support for D3D12_HEAP_TYPE_CUSTOM to custom pools. Unfinished.

Replaced member POOL_DESC::HeapType with HeapProperties. Compatibility breaking!
Incrased constant HEAP_TYPE_COUNT to 4.
diff --git a/src/D3D12MemAlloc.cpp b/src/D3D12MemAlloc.cpp
index f641c96..9d416c2 100644
--- a/src/D3D12MemAlloc.cpp
+++ b/src/D3D12MemAlloc.cpp
@@ -484,6 +484,7 @@
     case D3D12_HEAP_TYPE_DEFAULT:  return 0;

     case D3D12_HEAP_TYPE_UPLOAD:   return 1;

     case D3D12_HEAP_TYPE_READBACK: return 2;

+    case D3D12_HEAP_TYPE_CUSTOM:   return 3;

     default: D3D12MA_ASSERT(0); return UINT_MAX;

     }

 }

@@ -492,6 +493,7 @@
     L"DEFAULT",

     L"UPLOAD",

     L"READBACK",

+    L"CUSTOM",

 };

 

 // Stat helper functions

@@ -733,13 +735,6 @@
     return result;

 }

 

-static inline bool IsHeapTypeValid(D3D12_HEAP_TYPE type)

-{

-    return type == D3D12_HEAP_TYPE_DEFAULT ||

-        type == D3D12_HEAP_TYPE_UPLOAD ||

-        type == D3D12_HEAP_TYPE_READBACK;

-}

-

 ////////////////////////////////////////////////////////////////////////////////

 // Private class Vector

 

@@ -2458,14 +2453,14 @@
 public:

     MemoryBlock(

         AllocatorPimpl* allocator,

-        D3D12_HEAP_TYPE heapType,

+        const D3D12_HEAP_PROPERTIES& heapProps,

         D3D12_HEAP_FLAGS heapFlags,

         UINT64 size,

         UINT id);

     virtual ~MemoryBlock();

     // Creates the ID3D12Heap.

 

-    D3D12_HEAP_TYPE GetHeapType() const { return m_HeapType; }

+    const D3D12_HEAP_PROPERTIES& GetHeapProperties() const { return m_HeapProps; }

     D3D12_HEAP_FLAGS GetHeapFlags() const { return m_HeapFlags; }

     UINT64 GetSize() const { return m_Size; }

     UINT GetId() const { return m_Id; }

@@ -2473,7 +2468,7 @@
 

 protected:

     AllocatorPimpl* const m_Allocator;

-    const D3D12_HEAP_TYPE m_HeapType;

+    const D3D12_HEAP_PROPERTIES m_HeapProps;

     const D3D12_HEAP_FLAGS m_HeapFlags;

     const UINT64 m_Size;

     const UINT m_Id;

@@ -2502,7 +2497,7 @@
     NormalBlock(

         AllocatorPimpl* allocator,

         BlockVector* blockVector,

-        D3D12_HEAP_TYPE heapType,

+        const D3D12_HEAP_PROPERTIES& heapProps,

         D3D12_HEAP_FLAGS heapFlags,

         UINT64 size,

         UINT id);

@@ -2535,7 +2530,7 @@
 public:

     BlockVector(

         AllocatorPimpl* hAllocator,

-        D3D12_HEAP_TYPE heapType,

+        const D3D12_HEAP_PROPERTIES& heapProps,

         D3D12_HEAP_FLAGS heapFlags,

         UINT64 preferredBlockSize,

         size_t minBlockCount,

@@ -2545,7 +2540,7 @@
 

     HRESULT CreateMinBlocks();

 

-    UINT GetHeapType() const { return m_HeapType; }

+    const D3D12_HEAP_PROPERTIES& GetHeapProperties() const { return m_HeapProps; }

     UINT64 GetPreferredBlockSize() const { return m_PreferredBlockSize; }

 

     bool IsEmpty();

@@ -2594,7 +2589,7 @@
 

 private:

     AllocatorPimpl* const m_hAllocator;

-    const D3D12_HEAP_TYPE m_HeapType;

+    const D3D12_HEAP_PROPERTIES m_HeapProps;

     const D3D12_HEAP_FLAGS m_HeapFlags;

     const UINT64 m_PreferredBlockSize;

     const size_t m_MinBlockCount;

@@ -2639,6 +2634,7 @@
 ////////////////////////////////////////////////////////////////////////////////

 // Private class AllocatorPimpl definition

 

+static const UINT DEFAULT_POOL_HEAP_TYPE_COUNT = 3; // Only DEFAULT, UPLOAD, READBACK.

 static const UINT DEFAULT_POOL_MAX_COUNT = 9;

 

 struct CurrentBudgetData

@@ -3631,11 +3627,11 @@
 NormalBlock::NormalBlock(

     AllocatorPimpl* allocator,

     BlockVector* blockVector,

-    D3D12_HEAP_TYPE heapType,

+    const D3D12_HEAP_PROPERTIES& heapProps,

     D3D12_HEAP_FLAGS heapFlags,

     UINT64 size,

     UINT id) :

-    MemoryBlock(allocator, heapType, heapFlags, size, id),

+    MemoryBlock(allocator, heapProps, heapFlags, size, id),

     m_pMetadata(NULL),

     m_BlockVector(blockVector)

 {

@@ -3681,12 +3677,12 @@
 

 MemoryBlock::MemoryBlock(

     AllocatorPimpl* allocator,

-    D3D12_HEAP_TYPE heapType,

+    const D3D12_HEAP_PROPERTIES& heapProps,

     D3D12_HEAP_FLAGS heapFlags,

     UINT64 size,

     UINT id) :

     m_Allocator(allocator),

-    m_HeapType(heapType),

+    m_HeapProps(heapProps),

     m_HeapFlags(heapFlags),

     m_Size(size),

     m_Id(id)

@@ -3697,7 +3693,7 @@
 {

     if(m_Heap)

     {

-        m_Allocator->m_Budget.m_BlockBytes[HeapTypeToIndex(m_HeapType)] -= m_Size;

+        m_Allocator->m_Budget.m_BlockBytes[HeapTypeToIndex(m_HeapProps.Type)] -= m_Size;

         m_Heap->Release();

     }

 }

@@ -3708,14 +3704,14 @@
 

     D3D12_HEAP_DESC heapDesc = {};

     heapDesc.SizeInBytes = m_Size;

-    heapDesc.Properties.Type = m_HeapType;

+    heapDesc.Properties = m_HeapProps;

     heapDesc.Alignment = HeapFlagsToAlignment(m_HeapFlags);

     heapDesc.Flags = m_HeapFlags;

 

     HRESULT hr = m_Allocator->GetDevice()->CreateHeap(&heapDesc, __uuidof(*m_Heap), (void**)&m_Heap);

     if(SUCCEEDED(hr))

     {

-        m_Allocator->m_Budget.m_BlockBytes[HeapTypeToIndex(m_HeapType)] += m_Size;

+        m_Allocator->m_Budget.m_BlockBytes[HeapTypeToIndex(m_HeapProps.Type)] += m_Size;

     }

     return hr;

 }

@@ -3725,14 +3721,14 @@
 

 BlockVector::BlockVector(

     AllocatorPimpl* hAllocator,

-    D3D12_HEAP_TYPE heapType,

+    const D3D12_HEAP_PROPERTIES& heapProps,

     D3D12_HEAP_FLAGS heapFlags,

     UINT64 preferredBlockSize,

     size_t minBlockCount,

     size_t maxBlockCount,

     bool explicitBlockSize) :

     m_hAllocator(hAllocator),

-    m_HeapType(heapType),

+    m_HeapProps(heapProps),

     m_HeapFlags(heapFlags),

     m_PreferredBlockSize(preferredBlockSize),

     m_MinBlockCount(minBlockCount),

@@ -3823,10 +3819,11 @@
         return E_OUTOFMEMORY;

     }

 

-    UINT64 freeMemory;

+    UINT64 freeMemory = UINT64_MAX;

+    if(m_HeapProps.Type != D3D12_HEAP_TYPE_CUSTOM)

     {

         Budget budget = {};

-        m_hAllocator->GetBudgetForHeapType(budget, m_HeapType);

+        m_hAllocator->GetBudgetForHeapType(budget, m_HeapProps.Type);

         freeMemory = (budget.UsageBytes < budget.BudgetBytes) ? (budget.BudgetBytes - budget.UsageBytes) : 0;

     }

 

@@ -3938,9 +3935,10 @@
     NormalBlock* pBlockToDelete = NULL;

 

     bool budgetExceeded = false;

+    if(m_HeapProps.Type != D3D12_HEAP_TYPE_CUSTOM)

     {

         Budget budget = {};

-        m_hAllocator->GetBudgetForHeapType(budget, m_HeapType);

+        m_hAllocator->GetBudgetForHeapType(budget, m_HeapProps.Type);

         budgetExceeded = budget.UsageBytes >= budget.BudgetBytes;

     }

 

@@ -4173,7 +4171,10 @@
         pBlock->m_pMetadata->Alloc(currRequest, size, *pAllocation);

         (*pAllocation)->InitPlaced(currRequest.offset, alignment, pBlock);

         D3D12MA_HEAVY_ASSERT(pBlock->Validate());

-        m_hAllocator->m_Budget.AddAllocation(HeapTypeToIndex(m_HeapType), size);

+        if(m_HeapProps.Type != D3D12_HEAP_TYPE_CUSTOM)

+        {

+            m_hAllocator->m_Budget.AddAllocation(HeapTypeToIndex(m_HeapProps.Type), size);

+        }

         return S_OK;

     }

     return E_OUTOFMEMORY;

@@ -4184,7 +4185,7 @@
     NormalBlock* const pBlock = D3D12MA_NEW(m_hAllocator->GetAllocs(), NormalBlock)(

         m_hAllocator,

         this,

-        m_HeapType,

+        m_HeapProps,

         m_HeapFlags,

         blockSize,

         m_NextBlockId++);

@@ -4304,7 +4305,7 @@
 

 void BlockVector::AddStats(Stats& outStats)

 {

-    const UINT heapTypeIndex = HeapTypeToIndex(m_HeapType);

+    const UINT heapTypeIndex = HeapTypeToIndex(m_HeapProps.Type);

     StatInfo* const pStatInfo = &outStats.HeapType[heapTypeIndex];

 

     MutexLockRead lock(m_Mutex, m_hAllocator->UseMutex());

@@ -4359,7 +4360,7 @@
     UINT maxBlockCount = desc.MaxBlockCount != 0 ? desc.MaxBlockCount : UINT_MAX;

 

     m_BlockVector = D3D12MA_NEW(allocator->GetAllocs(), BlockVector)(

-        allocator, desc.HeapType, heapFlags,

+        allocator, desc.HeapProperties, heapFlags,

         preferredBlockSize,

         desc.MinBlockCount, maxBlockCount,

         explicitBlockSize);

@@ -4460,7 +4461,7 @@
 

 Pool::~Pool()

 {

-    m_Pimpl->GetAllocator()->UnregisterPool(this, m_Pimpl->GetDesc().HeapType);

+    m_Pimpl->GetAllocator()->UnregisterPool(this, m_Pimpl->GetDesc().HeapProperties.Type);

 

     D3D12MA_DELETE(m_Pimpl->GetAllocator()->GetAllocs(), m_Pimpl);

 }

@@ -4523,16 +4524,16 @@
     m_D3D12Options.ResourceHeapTier = (D3D12MA_FORCE_RESOURCE_HEAP_TIER);

 #endif

 

+    D3D12_HEAP_PROPERTIES heapProps = {};

     const UINT defaultPoolCount = CalcDefaultPoolCount();

     for(UINT i = 0; i < defaultPoolCount; ++i)

     {

-        D3D12_HEAP_TYPE heapType;

         D3D12_HEAP_FLAGS heapFlags;

-        CalcDefaultPoolParams(heapType, heapFlags, i);

+        CalcDefaultPoolParams(heapProps.Type, heapFlags, i);

 

         m_BlockVectors[i] = D3D12MA_NEW(GetAllocs(), BlockVector)(

             this, // hAllocator

-            heapType, // heapType

+            heapProps, // heapType

             heapFlags, // heapFlags

             m_PreferredBlockSize,

             0, // minBlockCount

@@ -4618,11 +4619,6 @@
         *ppvResource = NULL;

     }

 

-    if(pAllocDesc->CustomPool == NULL && !IsHeapTypeValid(pAllocDesc->HeapType))

-    {

-        return E_INVALIDARG;

-    }

-

     ALLOCATION_DESC finalAllocDesc = *pAllocDesc;

 

     D3D12_RESOURCE_DESC finalResourceDesc = *pResourceDesc;

@@ -4796,11 +4792,6 @@
         return E_NOINTERFACE;

     }

 

-    if(pAllocDesc->CustomPool == NULL && !IsHeapTypeValid(pAllocDesc->HeapType))

-    {

-        return E_INVALIDARG;

-    }

-

     ALLOCATION_DESC finalAllocDesc = *pAllocDesc;

 

     D3D12_RESOURCE_DESC1 finalResourceDesc = *pResourceDesc;

@@ -4931,11 +4922,6 @@
     }

     else

     {

-        if(!IsHeapTypeValid(pAllocDesc->HeapType))

-        {

-            return E_INVALIDARG;

-        }

-

         ALLOCATION_DESC finalAllocDesc = *pAllocDesc;

 

         const UINT defaultPoolIndex = CalcDefaultPoolIndex(*pAllocDesc);

@@ -5003,11 +4989,6 @@
         return E_INVALIDARG;

     }

 

-    if(!IsHeapTypeValid(pAllocDesc->HeapType))

-    {

-        return E_INVALIDARG;

-    }

-

     return AllocateHeap1(pAllocDesc, *pAllocInfo, pProtectedSession, ppAllocation);

 }

 #endif // #ifdef __ID3D12Device4_INTERFACE_DEFINED__

@@ -5056,7 +5037,7 @@
     D3D12_HEAP_FLAGS heapFlags,

     UINT64 minBytes)

 {

-    if(!IsHeapTypeValid(heapType))

+    if(heapType != D3D12_HEAP_TYPE_DEFAULT && heapType != D3D12_HEAP_TYPE_UPLOAD && heapType != D3D12_HEAP_TYPE_READBACK)

     {

         D3D12MA_ASSERT(0 && "Allocator::SetDefaultHeapMinBytes: Invalid heapType passed.");

         return E_INVALIDARG;

@@ -5592,7 +5573,7 @@
     D3D12MA_ASSERT(block);

     BlockVector* const blockVector = block->GetBlockVector();

     D3D12MA_ASSERT(blockVector);

-    m_Budget.RemoveAllocation(HeapTypeToIndex(block->GetHeapType()), allocation->GetSize());

+    m_Budget.RemoveAllocation(HeapTypeToIndex(block->GetHeapProperties().Type), allocation->GetSize());

     blockVector->Free(allocation);

 }

 

@@ -5636,7 +5617,7 @@
     

     if(SupportsResourceHeapTier2())

     {

-        for(size_t heapTypeIndex = 0; heapTypeIndex < HEAP_TYPE_COUNT; ++heapTypeIndex)

+        for(size_t heapTypeIndex = 0; heapTypeIndex < DEFAULT_POOL_HEAP_TYPE_COUNT; ++heapTypeIndex)

         {

             BlockVector* const pBlockVector = m_BlockVectors[heapTypeIndex];

             D3D12MA_ASSERT(pBlockVector);

@@ -5645,7 +5626,7 @@
     }

     else

     {

-        for(size_t heapTypeIndex = 0; heapTypeIndex < HEAP_TYPE_COUNT; ++heapTypeIndex)

+        for(size_t heapTypeIndex = 0; heapTypeIndex < DEFAULT_POOL_HEAP_TYPE_COUNT; ++heapTypeIndex)

         {

             for(size_t heapSubType = 0; heapSubType < 3; ++heapSubType)

             {

@@ -5865,7 +5846,7 @@
 

             if (SupportsResourceHeapTier2())

             {

-                for (size_t heapType = 0; heapType < HEAP_TYPE_COUNT; ++heapType)

+                for (size_t heapType = 0; heapType < DEFAULT_POOL_HEAP_TYPE_COUNT; ++heapType)

                 {

                     json.WriteString(HeapTypeNames[heapType]);

                     json.BeginObject();

@@ -5881,7 +5862,7 @@
             }

             else

             {

-                for (size_t heapType = 0; heapType < HEAP_TYPE_COUNT; ++heapType)

+                for (size_t heapType = 0; heapType < DEFAULT_POOL_HEAP_TYPE_COUNT; ++heapType)

                 {

                     for (size_t heapSubType = 0; heapSubType < 3; ++heapSubType)

                     {

@@ -6420,7 +6401,6 @@
     Pool** ppPool)

 {

     if(!pPoolDesc || !ppPool ||

-        !IsHeapTypeValid(pPoolDesc->HeapType) ||

         (pPoolDesc->MaxBlockCount > 0 && pPoolDesc->MaxBlockCount < pPoolDesc->MinBlockCount))

     {

         D3D12MA_ASSERT(0 && "Invalid arguments passed to Allocator::CreatePool.");

@@ -6436,7 +6416,7 @@
     HRESULT hr = (*ppPool)->m_Pimpl->Init();

     if(SUCCEEDED(hr))

     {

-        m_Pimpl->RegisterPool(*ppPool, pPoolDesc->HeapType);

+        m_Pimpl->RegisterPool(*ppPool, pPoolDesc->HeapProperties.Type);

     }

     else

     {

diff --git a/src/D3D12MemAlloc.h b/src/D3D12MemAlloc.h
index d340b81..affeee6 100644
--- a/src/D3D12MemAlloc.h
+++ b/src/D3D12MemAlloc.h
@@ -637,6 +637,13 @@
   are not going to be included into this repository.

 */

 

+// If using this library on a platform different than Windows PC, you should

+// include D3D12-compatible header before this library on your own and define this macro.

+#ifndef D3D12MA_D3D12_HEADERS_ALREADY_INCLUDED

+    #include <d3d12.h>

+    #include <dxgi1_6.h>

+#endif

+

 // Define this macro to 0 to disable usage of DXGI 1.4 (needed for IDXGIAdapter3 and query for memory budget).

 #ifndef D3D12MA_DXGI_1_4

     #ifdef __IDXGIAdapter3_INTERFACE_DEFINED__

@@ -646,13 +653,6 @@
     #endif

 #endif

 

-// If using this library on a platform different than Windows PC, you should

-// include D3D12-compatible header before this library on your own and define this macro.

-#ifndef D3D12MA_D3D12_HEADERS_ALREADY_INCLUDED

-    #include <d3d12.h>

-    #include <dxgi.h>

-#endif

-

 /*

 When defined to value other than 0, the library will try to use

 D3D12_SMALL_RESOURCE_PLACEMENT_ALIGNMENT or D3D12_SMALL_MSAA_RESOURCE_PLACEMENT_ALIGNMENT

@@ -970,11 +970,12 @@
 /// \brief Parameters of created D3D12MA::Pool object. To be used with D3D12MA::Allocator::CreatePool.

 struct POOL_DESC

 {

-    /** \brief The type of memory heap where allocations of this pool should be placed.

+    /** \brief The parameters of memory heap where allocations of this pool should be placed.

 

-    It must be one of: `D3D12_HEAP_TYPE_DEFAULT`, `D3D12_HEAP_TYPE_UPLOAD`, `D3D12_HEAP_TYPE_READBACK`.

+    In the simplest case, just fill it with zeros and set `Type` to one of: `D3D12_HEAP_TYPE_DEFAULT`,

+    `D3D12_HEAP_TYPE_UPLOAD`, `D3D12_HEAP_TYPE_READBACK`. Additional parameters can be used e.g. to utilize UMA.

     */

-    D3D12_HEAP_TYPE HeapType;

+    D3D12_HEAP_PROPERTIES HeapProperties;

     /** \brief Heap flags to be used when allocating heaps of this pool.

 

     It should contain one of these values, depending on type of resources you are going to create in this heap:

@@ -1128,7 +1129,7 @@
 /**

 \brief Number of D3D12 memory heap types supported.

 */

-const UINT HEAP_TYPE_COUNT = 3;

+const UINT HEAP_TYPE_COUNT = 4;

 

 /**

 \brief Calculated statistics of memory usage in entire allocator.

@@ -1162,7 +1163,7 @@
     StatInfo Total;

     /**

     One StatInfo for each type of heap located at the following indices:

-    0 - DEFAULT, 1 - UPLOAD, 2 - READBACK.

+    0 - DEFAULT, 1 - UPLOAD, 2 - READBACK, 3 - CUSTOM.

     */

     StatInfo HeapType[HEAP_TYPE_COUNT];

 };

diff --git a/src/Tests.cpp b/src/Tests.cpp
index eb609bc..7a65ad5 100644
--- a/src/Tests.cpp
+++ b/src/Tests.cpp
@@ -572,7 +572,7 @@
     // # Create pool, 1..2 blocks of 11 MB

     

     D3D12MA::POOL_DESC poolDesc = {};

-    poolDesc.HeapType = D3D12_HEAP_TYPE_DEFAULT;

+    poolDesc.HeapProperties.Type = D3D12_HEAP_TYPE_DEFAULT;

     poolDesc.HeapFlags = D3D12_HEAP_FLAG_ALLOW_ONLY_BUFFERS;

     poolDesc.BlockSize = 11 * MEGABYTE;

     poolDesc.MinBlockCount = 1;