diff --git a/src/D3D12MemAlloc.cpp b/src/D3D12MemAlloc.cpp
index 808efcc..bf8b43d 100644
--- a/src/D3D12MemAlloc.cpp
+++ b/src/D3D12MemAlloc.cpp
@@ -29,6 +29,7 @@
     #endif
 #endif
 
+#include <combaseapi.h>
 #include <mutex>
 #include <atomic>
 #include <algorithm>
@@ -3634,15 +3635,21 @@
             &resourceDesc,
             InitialResourceState,
             pOptimizedClearValue,
-            riidResource,
-            (void**)&res);
+            IID_PPV_ARGS(&res));
         if(SUCCEEDED(hr))
         {
-            (*ppAllocation)->SetResource(res, &resourceDesc);
             if(ppvResource != NULL)
             {
-                res->AddRef();
-                *ppvResource = res;
+                hr = res->QueryInterface(riidResource, ppvResource);
+            }
+            if(SUCCEEDED(hr))
+            {
+                (*ppAllocation)->SetResource(res, &resourceDesc);
+            }
+            else
+            {
+                res->Release();
+                SAFE_RELEASE(*ppAllocation);
             }
         }
         else
@@ -4515,31 +4522,37 @@
     D3D12_HEAP_PROPERTIES heapProps = {};
     heapProps.Type = pAllocDesc->HeapType;
 
-    D3D12_HEAP_FLAGS heapFlags = pAllocDesc->ExtraHeapFlags;
+    const D3D12_HEAP_FLAGS heapFlags = pAllocDesc->ExtraHeapFlags;
 
     ID3D12Resource* res = NULL;
     HRESULT hr = m_Device->CreateCommittedResource(
         &heapProps, heapFlags, pResourceDesc, InitialResourceState,
-        pOptimizedClearValue, riidResource, (void**)&res);
+        pOptimizedClearValue, IID_PPV_ARGS(&res));
     if(SUCCEEDED(hr))
     {
-        const BOOL wasZeroInitialized = TRUE;
-        Allocation* alloc = m_AllocationObjectAllocator.Allocate(this, resAllocInfo.SizeInBytes, wasZeroInitialized);
-        alloc->InitCommitted(pAllocDesc->HeapType);
-        alloc->SetResource(res, pResourceDesc);
-
-        *ppAllocation = alloc;
         if(ppvResource != NULL)
         {
-            res->AddRef();
-            *ppvResource = res;
+            hr = res->QueryInterface(riidResource, ppvResource);
         }
+        if(SUCCEEDED(hr))
+        {
+            const BOOL wasZeroInitialized = TRUE;
+            Allocation* alloc = m_AllocationObjectAllocator.Allocate(this, resAllocInfo.SizeInBytes, wasZeroInitialized);
+            alloc->InitCommitted(pAllocDesc->HeapType);
+            alloc->SetResource(res, pResourceDesc);
 
-        RegisterCommittedAllocation(*ppAllocation, pAllocDesc->HeapType);
+            *ppAllocation = alloc;
 
-        const UINT heapTypeIndex = HeapTypeToIndex(pAllocDesc->HeapType);
-        m_Budget.AddAllocation(heapTypeIndex, resAllocInfo.SizeInBytes);
-        m_Budget.m_BlockBytes[heapTypeIndex] += resAllocInfo.SizeInBytes;
+            RegisterCommittedAllocation(*ppAllocation, pAllocDesc->HeapType);
+
+            const UINT heapTypeIndex = HeapTypeToIndex(pAllocDesc->HeapType);
+            m_Budget.AddAllocation(heapTypeIndex, resAllocInfo.SizeInBytes);
+            m_Budget.m_BlockBytes[heapTypeIndex] += resAllocInfo.SizeInBytes;
+        }
+        else
+        {
+            res->Release();
+        }
     }
     return hr;
 }
@@ -5359,8 +5372,7 @@
 
 void Allocation::SetResource(ID3D12Resource* resource, const D3D12_RESOURCE_DESC* pResourceDesc)
 {
-    D3D12MA_ASSERT(m_Resource == NULL);
-    D3D12MA_ASSERT(pResourceDesc);
+    D3D12MA_ASSERT(m_Resource == NULL && pResourceDesc);
     m_Resource = resource;
     m_PackedData.SetResourceDimension(pResourceDesc->Dimension);
     m_PackedData.SetResourceFlags(pResourceDesc->Flags);
@@ -5435,7 +5447,7 @@
     REFIID riidResource,
     void** ppvResource)
 {
-    if(!pAllocDesc || !pResourceDesc || !ppAllocation || riidResource == IID_NULL)
+    if(!pAllocDesc || !pResourceDesc || !ppAllocation)
     {
         D3D12MA_ASSERT(0 && "Invalid arguments passed to Allocator::CreateResource.");
         return E_INVALIDARG;
@@ -5474,7 +5486,7 @@
     REFIID riidResource,
     void** ppvResource)
 {
-    if(!pAllocation || !pResourceDesc || riidResource == IID_NULL || !ppvResource)
+    if(!pAllocation || !pResourceDesc || !ppvResource)
     {
         D3D12MA_ASSERT(0 && "Invalid arguments passed to Allocator::CreateAliasingResource.");
         return E_INVALIDARG;
diff --git a/src/D3D12MemAlloc.h b/src/D3D12MemAlloc.h
index a3d60fe..a8a82a8 100644
--- a/src/D3D12MemAlloc.h
+++ b/src/D3D12MemAlloc.h
@@ -1221,7 +1221,7 @@
     It is automatically destroyed when you destroy the allocation object.
 
     If 'ppvResource` is not null, you receive pointer to the resource next to allocation object.
-    Reference count of the resource object is then 2, so you need to manually `Release` it
+    Reference count of the resource object is then increased by calling `QueryInterface`, so you need to manually `Release` it
     along with the allocation.
 
     \param pAllocDesc   Parameters of the allocation.
@@ -1229,7 +1229,7 @@
     \param InitialResourceState   Initial resource state.
     \param pOptimizedClearValue   Optional. Either null or optimized clear value.
     \param[out] ppAllocation   Filled with pointer to new allocation object created.
-    \param riidResource   IID of a resource to be created. Must be `__uuidof(ID3D12Resource)`.
+    \param riidResource   IID of a resource to be returned via `ppvResource`.
     \param[out] ppvResource   Optional. If not null, filled with pointer to new resouce created.
     */
     HRESULT CreateResource(
diff --git a/src/Tests.cpp b/src/Tests.cpp
index 3152a94..d37aa28 100644
--- a/src/Tests.cpp
+++ b/src/Tests.cpp
@@ -524,6 +524,41 @@
     renderTargetRes.allocation.reset(alloc);
 }
 
+static void TestOtherComInterface(const TestContext& ctx)
+{
+    wprintf(L"Test other COM interface\n");
+
+    D3D12_RESOURCE_DESC resDesc;
+    FillResourceDescForBuffer(resDesc, 0x10000);
+
+    for(uint32_t i = 0; i < 2; ++i)
+    {
+        D3D12MA::ALLOCATION_DESC allocDesc = {};
+        allocDesc.HeapType = D3D12_HEAP_TYPE_DEFAULT;
+        if(i == 1)
+        {
+            allocDesc.Flags = D3D12MA::ALLOCATION_FLAG_COMMITTED;
+        }
+
+        D3D12MA::Allocation* alloc = nullptr;
+        CComPtr<ID3D12Pageable> pageable;
+        CHECK_HR(ctx.allocator->CreateResource(
+            &allocDesc,
+            &resDesc,
+            D3D12_RESOURCE_STATE_COMMON,
+            nullptr, // pOptimizedClearValue
+            &alloc,
+            IID_PPV_ARGS(&pageable)));
+
+        // Do something with the interface to make sure it's valid.
+        CComPtr<ID3D12Device> device;
+        CHECK_HR(pageable->GetDevice(IID_PPV_ARGS(&device)));
+        CHECK_BOOL(device == ctx.device);
+
+        alloc->Release();
+    }
+}
+
 static void TestCustomPools(const TestContext& ctx)
 {
     wprintf(L"Test custom pools\n");
@@ -1311,6 +1346,7 @@
     TestCommittedResourcesAndJson(ctx);
     TestCustomHeapFlags(ctx);
     TestPlacedResources(ctx);
+    TestOtherComInterface(ctx);
     TestCustomPools(ctx);
     TestDefaultPoolMinBytes(ctx);
     TestAliasingMemory(ctx);
