Change logic of Allocator::CreateResource so that an interface other than ID3D12Resource can be requested

Wrote test for it.
diff --git a/src/D3D12MemAlloc.cpp b/src/D3D12MemAlloc.cpp
index 808efcc..bf8b43d 100644
--- a/src/D3D12MemAlloc.cpp
+++ b/src/D3D12MemAlloc.cpp
@@ -29,6 +29,7 @@
     #endif

 #endif

 

+#include <combaseapi.h>

 #include <mutex>

 #include <atomic>

 #include <algorithm>

@@ -3634,15 +3635,21 @@
             &resourceDesc,

             InitialResourceState,

             pOptimizedClearValue,

-            riidResource,

-            (void**)&res);

+            IID_PPV_ARGS(&res));

         if(SUCCEEDED(hr))

         {

-            (*ppAllocation)->SetResource(res, &resourceDesc);

             if(ppvResource != NULL)

             {

-                res->AddRef();

-                *ppvResource = res;

+                hr = res->QueryInterface(riidResource, ppvResource);

+            }

+            if(SUCCEEDED(hr))

+            {

+                (*ppAllocation)->SetResource(res, &resourceDesc);

+            }

+            else

+            {

+                res->Release();

+                SAFE_RELEASE(*ppAllocation);

             }

         }

         else

@@ -4515,31 +4522,37 @@
     D3D12_HEAP_PROPERTIES heapProps = {};

     heapProps.Type = pAllocDesc->HeapType;

 

-    D3D12_HEAP_FLAGS heapFlags = pAllocDesc->ExtraHeapFlags;

+    const D3D12_HEAP_FLAGS heapFlags = pAllocDesc->ExtraHeapFlags;

 

     ID3D12Resource* res = NULL;

     HRESULT hr = m_Device->CreateCommittedResource(

         &heapProps, heapFlags, pResourceDesc, InitialResourceState,

-        pOptimizedClearValue, riidResource, (void**)&res);

+        pOptimizedClearValue, IID_PPV_ARGS(&res));

     if(SUCCEEDED(hr))

     {

-        const BOOL wasZeroInitialized = TRUE;

-        Allocation* alloc = m_AllocationObjectAllocator.Allocate(this, resAllocInfo.SizeInBytes, wasZeroInitialized);

-        alloc->InitCommitted(pAllocDesc->HeapType);

-        alloc->SetResource(res, pResourceDesc);

-

-        *ppAllocation = alloc;

         if(ppvResource != NULL)

         {

-            res->AddRef();

-            *ppvResource = res;

+            hr = res->QueryInterface(riidResource, ppvResource);

         }

+        if(SUCCEEDED(hr))

+        {

+            const BOOL wasZeroInitialized = TRUE;

+            Allocation* alloc = m_AllocationObjectAllocator.Allocate(this, resAllocInfo.SizeInBytes, wasZeroInitialized);

+            alloc->InitCommitted(pAllocDesc->HeapType);

+            alloc->SetResource(res, pResourceDesc);

 

-        RegisterCommittedAllocation(*ppAllocation, pAllocDesc->HeapType);

+            *ppAllocation = alloc;

 

-        const UINT heapTypeIndex = HeapTypeToIndex(pAllocDesc->HeapType);

-        m_Budget.AddAllocation(heapTypeIndex, resAllocInfo.SizeInBytes);

-        m_Budget.m_BlockBytes[heapTypeIndex] += resAllocInfo.SizeInBytes;

+            RegisterCommittedAllocation(*ppAllocation, pAllocDesc->HeapType);

+

+            const UINT heapTypeIndex = HeapTypeToIndex(pAllocDesc->HeapType);

+            m_Budget.AddAllocation(heapTypeIndex, resAllocInfo.SizeInBytes);

+            m_Budget.m_BlockBytes[heapTypeIndex] += resAllocInfo.SizeInBytes;

+        }

+        else

+        {

+            res->Release();

+        }

     }

     return hr;

 }

@@ -5359,8 +5372,7 @@
 

 void Allocation::SetResource(ID3D12Resource* resource, const D3D12_RESOURCE_DESC* pResourceDesc)

 {

-    D3D12MA_ASSERT(m_Resource == NULL);

-    D3D12MA_ASSERT(pResourceDesc);

+    D3D12MA_ASSERT(m_Resource == NULL && pResourceDesc);

     m_Resource = resource;

     m_PackedData.SetResourceDimension(pResourceDesc->Dimension);

     m_PackedData.SetResourceFlags(pResourceDesc->Flags);

@@ -5435,7 +5447,7 @@
     REFIID riidResource,

     void** ppvResource)

 {

-    if(!pAllocDesc || !pResourceDesc || !ppAllocation || riidResource == IID_NULL)

+    if(!pAllocDesc || !pResourceDesc || !ppAllocation)

     {

         D3D12MA_ASSERT(0 && "Invalid arguments passed to Allocator::CreateResource.");

         return E_INVALIDARG;

@@ -5474,7 +5486,7 @@
     REFIID riidResource,

     void** ppvResource)

 {

-    if(!pAllocation || !pResourceDesc || riidResource == IID_NULL || !ppvResource)

+    if(!pAllocation || !pResourceDesc || !ppvResource)

     {

         D3D12MA_ASSERT(0 && "Invalid arguments passed to Allocator::CreateAliasingResource.");

         return E_INVALIDARG;

diff --git a/src/D3D12MemAlloc.h b/src/D3D12MemAlloc.h
index a3d60fe..a8a82a8 100644
--- a/src/D3D12MemAlloc.h
+++ b/src/D3D12MemAlloc.h
@@ -1221,7 +1221,7 @@
     It is automatically destroyed when you destroy the allocation object.

 

     If 'ppvResource` is not null, you receive pointer to the resource next to allocation object.

-    Reference count of the resource object is then 2, so you need to manually `Release` it

+    Reference count of the resource object is then increased by calling `QueryInterface`, so you need to manually `Release` it

     along with the allocation.

 

     \param pAllocDesc   Parameters of the allocation.

@@ -1229,7 +1229,7 @@
     \param InitialResourceState   Initial resource state.

     \param pOptimizedClearValue   Optional. Either null or optimized clear value.

     \param[out] ppAllocation   Filled with pointer to new allocation object created.

-    \param riidResource   IID of a resource to be created. Must be `__uuidof(ID3D12Resource)`.

+    \param riidResource   IID of a resource to be returned via `ppvResource`.

     \param[out] ppvResource   Optional. If not null, filled with pointer to new resouce created.

     */

     HRESULT CreateResource(

diff --git a/src/Tests.cpp b/src/Tests.cpp
index 3152a94..d37aa28 100644
--- a/src/Tests.cpp
+++ b/src/Tests.cpp
@@ -524,6 +524,41 @@
     renderTargetRes.allocation.reset(alloc);

 }

 

+static void TestOtherComInterface(const TestContext& ctx)

+{

+    wprintf(L"Test other COM interface\n");

+

+    D3D12_RESOURCE_DESC resDesc;

+    FillResourceDescForBuffer(resDesc, 0x10000);

+

+    for(uint32_t i = 0; i < 2; ++i)

+    {

+        D3D12MA::ALLOCATION_DESC allocDesc = {};

+        allocDesc.HeapType = D3D12_HEAP_TYPE_DEFAULT;

+        if(i == 1)

+        {

+            allocDesc.Flags = D3D12MA::ALLOCATION_FLAG_COMMITTED;

+        }

+

+        D3D12MA::Allocation* alloc = nullptr;

+        CComPtr<ID3D12Pageable> pageable;

+        CHECK_HR(ctx.allocator->CreateResource(

+            &allocDesc,

+            &resDesc,

+            D3D12_RESOURCE_STATE_COMMON,

+            nullptr, // pOptimizedClearValue

+            &alloc,

+            IID_PPV_ARGS(&pageable)));

+

+        // Do something with the interface to make sure it's valid.

+        CComPtr<ID3D12Device> device;

+        CHECK_HR(pageable->GetDevice(IID_PPV_ARGS(&device)));

+        CHECK_BOOL(device == ctx.device);

+

+        alloc->Release();

+    }

+}

+

 static void TestCustomPools(const TestContext& ctx)

 {

     wprintf(L"Test custom pools\n");

@@ -1311,6 +1346,7 @@
     TestCommittedResourcesAndJson(ctx);

     TestCustomHeapFlags(ctx);

     TestPlacedResources(ctx);

+    TestOtherComInterface(ctx);

     TestCustomPools(ctx);

     TestDefaultPoolMinBytes(ctx);

     TestAliasingMemory(ctx);