layers: Refactor DESCRIPTOR_POOL_STATE

Move descriptor set lifecycle management into methods of this class.
diff --git a/layers/core_validation.cpp b/layers/core_validation.cpp
index b67b5cf..4682a9d 100644
--- a/layers/core_validation.cpp
+++ b/layers/core_validation.cpp
@@ -6304,9 +6304,10 @@
     // Make sure sets being destroyed are not currently in-use
     if (disabled[object_in_use]) return false;
     bool skip = false;
-    const DESCRIPTOR_POOL_STATE *pool = GetDescriptorPoolState(descriptorPool);
-    if (pool != nullptr) {
-        for (auto *ds : pool->sets) {
+    const auto pool = Get<DESCRIPTOR_POOL_STATE>(descriptorPool);
+    if (pool) {
+        for (const auto &entry : pool->sets) {
+            const auto *ds = entry.second;
             if (ds && ds->InUse()) {
                 skip |= LogError(descriptorPool, "VUID-vkResetDescriptorPool-descriptorPool-00313",
                                  "It is invalid to call vkResetDescriptorPool() with descriptor sets in use by a command buffer.");
diff --git a/layers/descriptor_sets.cpp b/layers/descriptor_sets.cpp
index 5fe359a..edd0ac2 100644
--- a/layers/descriptor_sets.cpp
+++ b/layers/descriptor_sets.cpp
@@ -34,6 +34,90 @@
 #include <array>
 #include <memory>
 
+static DESCRIPTOR_POOL_STATE::TypeCountMap GetMaxTypeCounts(const VkDescriptorPoolCreateInfo *create_info) {
+    DESCRIPTOR_POOL_STATE::TypeCountMap counts;
+    // Collect maximums per descriptor type.
+    for (uint32_t i = 0; i < create_info->poolSizeCount; ++i) {
+        const auto &pool_size = create_info->pPoolSizes[i];
+        uint32_t type = static_cast<uint32_t>(pool_size.type);
+        // Same descriptor types can appear several times
+        counts[type] += pool_size.descriptorCount;
+    }
+    return counts;
+}
+
+DESCRIPTOR_POOL_STATE::DESCRIPTOR_POOL_STATE(ValidationStateTracker *dev, const VkDescriptorPool pool,
+                                             const VkDescriptorPoolCreateInfo *pCreateInfo)
+    : BASE_NODE(pool, kVulkanObjectTypeDescriptorPool),
+      dev_data(dev),
+      maxSets(pCreateInfo->maxSets),
+      availableSets(pCreateInfo->maxSets),
+      createInfo(pCreateInfo),
+      maxDescriptorTypeCount(GetMaxTypeCounts(pCreateInfo)),
+      availableDescriptorTypeCount(maxDescriptorTypeCount) {}
+
+void DESCRIPTOR_POOL_STATE::Allocate(const VkDescriptorSetAllocateInfo *alloc_info, const VkDescriptorSet *descriptor_sets,
+                                     const cvdescriptorset::AllocateDescriptorSetsData *ds_data) {
+    // Account for sets and individual descriptors allocated from pool
+    availableSets -= alloc_info->descriptorSetCount;
+    for (auto it = ds_data->required_descriptors_by_type.begin(); it != ds_data->required_descriptors_by_type.end(); ++it) {
+        availableDescriptorTypeCount[it->first] -= ds_data->required_descriptors_by_type.at(it->first);
+    }
+
+    const auto *variable_count_info = LvlFindInChain<VkDescriptorSetVariableDescriptorCountAllocateInfo>(alloc_info->pNext);
+    bool variable_count_valid = variable_count_info && variable_count_info->descriptorSetCount == alloc_info->descriptorSetCount;
+
+    // Create tracking object for each descriptor set; insert into global map and the pool's set.
+    for (uint32_t i = 0; i < alloc_info->descriptorSetCount; i++) {
+        uint32_t variable_count = variable_count_valid ? variable_count_info->pDescriptorCounts[i] : 0;
+
+        auto new_ds = std::make_shared<cvdescriptorset::DescriptorSet>(descriptor_sets[i], this, ds_data->layout_nodes[i],
+                                                                       variable_count, dev_data);
+        sets.emplace(descriptor_sets[i], new_ds.get());
+        dev_data->setMap.emplace(descriptor_sets[i], std::move(new_ds));
+    }
+}
+
+void DESCRIPTOR_POOL_STATE::Free(uint32_t count, const VkDescriptorSet *descriptor_sets) {
+    // Update available descriptor sets in pool
+    availableSets += count;
+
+    // For each freed descriptor add its resources back into the pool as available and remove from pool and setMap
+    for (uint32_t i = 0; i < count; ++i) {
+        if (descriptor_sets[i] != VK_NULL_HANDLE) {
+            auto iter = sets.find(descriptor_sets[i]);
+            assert(iter != sets.end());
+            auto *set_state = iter->second;
+            uint32_t type_index = 0, descriptor_count = 0;
+            for (uint32_t j = 0; j < set_state->GetBindingCount(); ++j) {
+                type_index = static_cast<uint32_t>(set_state->GetTypeFromIndex(j));
+                descriptor_count = set_state->GetDescriptorCountFromIndex(j);
+                availableDescriptorTypeCount[type_index] += descriptor_count;
+            }
+            set_state->Destroy();
+            dev_data->setMap.erase(iter->first);
+            sets.erase(iter);
+        }
+    }
+}
+
+void DESCRIPTOR_POOL_STATE::Reset() {
+    // For every set off of this pool, clear it, remove from setMap, and free cvdescriptorset::DescriptorSet
+    for (auto entry : sets) {
+        entry.second->Destroy();
+        dev_data->setMap.erase(entry.first);
+    }
+    sets.clear();
+    // Reset available count for each type and available sets for this pool
+    availableDescriptorTypeCount = maxDescriptorTypeCount;
+    availableSets = maxSets;
+}
+
+void DESCRIPTOR_POOL_STATE::Destroy() {
+    Reset();
+    BASE_NODE::Destroy();
+}
+
 // ExtendedBinding collects a VkDescriptorSetLayoutBinding and any extended
 // state that comes from a different array/structure so they can stay together
 // while being sorted by binding number.
@@ -3206,10 +3290,10 @@
     if (!IsExtEnabled(device_extensions.vk_khr_maintenance1)) {
         // Track number of descriptorSets allowable in this pool
         if (pool_state->availableSets < p_alloc_info->descriptorSetCount) {
-            skip |= LogError(pool_state->pool, "VUID-VkDescriptorSetAllocateInfo-descriptorSetCount-00306",
+            skip |= LogError(pool_state->Handle(), "VUID-VkDescriptorSetAllocateInfo-descriptorSetCount-00306",
                              "vkAllocateDescriptorSets(): Unable to allocate %u descriptorSets from %s"
                              ". This pool only has %d descriptorSets remaining.",
-                             p_alloc_info->descriptorSetCount, report_data->FormatHandle(pool_state->pool).c_str(),
+                             p_alloc_info->descriptorSetCount, report_data->FormatHandle(pool_state->Handle()).c_str(),
                              pool_state->availableSets);
         }
         // Determine whether descriptor counts are satisfiable
@@ -3218,12 +3302,12 @@
             uint32_t available_count = (count_iter != pool_state->availableDescriptorTypeCount.end()) ? count_iter->second : 0;
 
             if (ds_data->required_descriptors_by_type.at(it->first) > available_count) {
-                skip |= LogError(pool_state->pool, "VUID-VkDescriptorSetAllocateInfo-descriptorPool-00307",
+                skip |= LogError(pool_state->Handle(), "VUID-VkDescriptorSetAllocateInfo-descriptorPool-00307",
                                  "vkAllocateDescriptorSets(): Unable to allocate %u descriptors of type %s from %s"
                                  ". This pool only has %d descriptors of this type remaining.",
                                  ds_data->required_descriptors_by_type.at(it->first),
                                  string_VkDescriptorType(VkDescriptorType(it->first)),
-                                 report_data->FormatHandle(pool_state->pool).c_str(), available_count);
+                                 report_data->FormatHandle(pool_state->Handle()).c_str(), available_count);
             }
         }
     }
diff --git a/layers/descriptor_sets.h b/layers/descriptor_sets.h
index 5902951..370380e 100644
--- a/layers/descriptor_sets.h
+++ b/layers/descriptor_sets.h
@@ -49,35 +49,29 @@
 
 namespace cvdescriptorset {
 class DescriptorSet;
+struct AllocateDescriptorSetsData;
 }
 
 class DESCRIPTOR_POOL_STATE : public BASE_NODE {
   public:
-    VkDescriptorPool pool;
-    uint32_t maxSets;        // Max descriptor sets allowed in this pool
+    ValidationStateTracker *dev_data;
+    const uint32_t maxSets;  // Max descriptor sets allowed in this pool
     uint32_t availableSets;  // Available descriptor sets in this pool
 
-    safe_VkDescriptorPoolCreateInfo createInfo;
-    layer_data::unordered_set<cvdescriptorset::DescriptorSet *> sets;  // Collection of all sets in this pool
-    std::map<uint32_t, uint32_t> maxDescriptorTypeCount;               // Max # of descriptors of each type in this pool
-    std::map<uint32_t, uint32_t> availableDescriptorTypeCount;         // Available # of descriptors of each type in this pool
+    const safe_VkDescriptorPoolCreateInfo createInfo;
+    using TypeCountMap = layer_data::unordered_map<uint32_t, uint32_t>;
+    const TypeCountMap maxDescriptorTypeCount;  // Max # of descriptors of each type in this pool
+    TypeCountMap availableDescriptorTypeCount;  // Available # of descriptors of each type in this pool
+    layer_data::unordered_map<VkDescriptorSet, cvdescriptorset::DescriptorSet *> sets;  // Collection of all sets in this pool
 
-    DESCRIPTOR_POOL_STATE(const VkDescriptorPool pool, const VkDescriptorPoolCreateInfo *pCreateInfo)
-        : BASE_NODE(pool, kVulkanObjectTypeDescriptorPool),
-          pool(pool),
-          maxSets(pCreateInfo->maxSets),
-          availableSets(pCreateInfo->maxSets),
-          createInfo(pCreateInfo),
-          maxDescriptorTypeCount(),
-          availableDescriptorTypeCount() {
-        // Collect maximums per descriptor type.
-        for (uint32_t i = 0; i < createInfo.poolSizeCount; ++i) {
-            uint32_t typeIndex = static_cast<uint32_t>(createInfo.pPoolSizes[i].type);
-            // Same descriptor types can appear several times
-            maxDescriptorTypeCount[typeIndex] += createInfo.pPoolSizes[i].descriptorCount;
-            availableDescriptorTypeCount[typeIndex] = maxDescriptorTypeCount[typeIndex];
-        }
-    }
+    DESCRIPTOR_POOL_STATE(ValidationStateTracker *dev, const VkDescriptorPool pool, const VkDescriptorPoolCreateInfo *pCreateInfo);
+    ~DESCRIPTOR_POOL_STATE() { Destroy(); }
+
+    void Allocate(const VkDescriptorSetAllocateInfo *alloc_info, const VkDescriptorSet *descriptor_sets,
+                  const cvdescriptorset::AllocateDescriptorSetsData *ds_data);
+    void Free(uint32_t count, const VkDescriptorSet *descriptor_sets);
+    void Reset();
+    void Destroy() override;
 };
 
 // Descriptor Data structures
diff --git a/layers/state_tracker.cpp b/layers/state_tracker.cpp
index 70e46de..ec43ffb 100644
--- a/layers/state_tracker.cpp
+++ b/layers/state_tracker.cpp
@@ -445,27 +445,6 @@
     return GetObjectMemBindingImpl<ValidationStateTracker *, BINDABLE *>(this, typed_handle);
 }
 
-// Remove set from setMap and delete the set
-void ValidationStateTracker::FreeDescriptorSet(cvdescriptorset::DescriptorSet *descriptor_set) {
-    // Any bound cmd buffers are now invalid
-    descriptor_set->Destroy();
-
-    setMap.erase(descriptor_set->GetSet());
-}
-
-// Free all DS Pools including their Sets & related sub-structs
-// NOTE : Calls to this function should be wrapped in mutex
-void ValidationStateTracker::DeleteDescriptorSetPools() {
-    for (auto ii = descriptorPoolMap.begin(); ii != descriptorPoolMap.end();) {
-        // Remove this pools' sets from setMap and delete them
-        for (auto *ds : ii->second->sets) {
-            FreeDescriptorSet(ds);
-        }
-        ii->second->sets.clear();
-        ii = descriptorPoolMap.erase(ii);
-    }
-}
-
 // For given object struct return a ptr of BASE_NODE type for its wrapping struct
 BASE_NODE *ValidationStateTracker::GetStateStructPtrFromObject(const VulkanTypedHandle &object_struct) {
     if (object_struct.node) {
@@ -1357,7 +1336,7 @@
     renderPassMap.clear();
 
     // This will also delete all sets in the pool & remove them from setMap
-    DeleteDescriptorSetPools();
+    descriptorPoolMap.clear();
     // All sets should be removed
     assert(setMap.empty());
     descriptorSetLayoutMap.clear();
@@ -1889,13 +1868,8 @@
 
 void ValidationStateTracker::PreCallRecordDestroyDescriptorPool(VkDevice device, VkDescriptorPool descriptorPool,
                                                                 const VkAllocationCallbacks *pAllocator) {
-    if (!descriptorPool) return;
-    DESCRIPTOR_POOL_STATE *desc_pool_state = GetDescriptorPoolState(descriptorPool);
+    auto *desc_pool_state = Get<DESCRIPTOR_POOL_STATE>(descriptorPool);
     if (desc_pool_state) {
-        // Free sets that were in this pool
-        for (auto *ds : desc_pool_state->sets) {
-            FreeDescriptorSet(ds);
-        }
         desc_pool_state->Destroy();
         descriptorPoolMap.erase(descriptorPool);
     }
@@ -2167,24 +2141,16 @@
                                                                 const VkAllocationCallbacks *pAllocator,
                                                                 VkDescriptorPool *pDescriptorPool, VkResult result) {
     if (VK_SUCCESS != result) return;
-    descriptorPoolMap[*pDescriptorPool] = std::make_shared<DESCRIPTOR_POOL_STATE>(*pDescriptorPool, pCreateInfo);
+    descriptorPoolMap.emplace(*pDescriptorPool, std::make_shared<DESCRIPTOR_POOL_STATE>(this, *pDescriptorPool, pCreateInfo));
 }
 
 void ValidationStateTracker::PostCallRecordResetDescriptorPool(VkDevice device, VkDescriptorPool descriptorPool,
                                                                VkDescriptorPoolResetFlags flags, VkResult result) {
     if (VK_SUCCESS != result) return;
-    DESCRIPTOR_POOL_STATE *pool = GetDescriptorPoolState(descriptorPool);
-    // TODO: validate flags
-    // For every set off of this pool, clear it, remove from setMap, and free cvdescriptorset::DescriptorSet
-    for (auto *ds : pool->sets) {
-        FreeDescriptorSet(ds);
+    auto pool = Get<DESCRIPTOR_POOL_STATE>(descriptorPool);
+    if (pool) {
+        pool->Reset();
     }
-    pool->sets.clear();
-    // Reset available count for each type and available sets for this pool
-    for (auto it = pool->availableDescriptorTypeCount.begin(); it != pool->availableDescriptorTypeCount.end(); ++it) {
-        pool->availableDescriptorTypeCount[it->first] = pool->maxDescriptorTypeCount[it->first];
-    }
-    pool->availableSets = pool->maxSets;
 }
 
 bool ValidationStateTracker::PreCallValidateAllocateDescriptorSets(VkDevice device,
@@ -2206,28 +2172,17 @@
     // All the updates are contained in a single cvdescriptorset function
     cvdescriptorset::AllocateDescriptorSetsData *ads_state =
         reinterpret_cast<cvdescriptorset::AllocateDescriptorSetsData *>(ads_state_data);
-    PerformAllocateDescriptorSets(pAllocateInfo, pDescriptorSets, ads_state);
+    auto pool_state = Get<DESCRIPTOR_POOL_STATE>(pAllocateInfo->descriptorPool);
+    if (pool_state) {
+        pool_state->Allocate(pAllocateInfo, pDescriptorSets, ads_state);
+    }
 }
 
 void ValidationStateTracker::PreCallRecordFreeDescriptorSets(VkDevice device, VkDescriptorPool descriptorPool, uint32_t count,
                                                              const VkDescriptorSet *pDescriptorSets) {
-    DESCRIPTOR_POOL_STATE *pool_state = GetDescriptorPoolState(descriptorPool);
-    // Update available descriptor sets in pool
-    pool_state->availableSets += count;
-
-    // For each freed descriptor add its resources back into the pool as available and remove from pool and setMap
-    for (uint32_t i = 0; i < count; ++i) {
-        if (pDescriptorSets[i] != VK_NULL_HANDLE) {
-            auto descriptor_set = setMap[pDescriptorSets[i]].get();
-            uint32_t type_index = 0, descriptor_count = 0;
-            for (uint32_t j = 0; j < descriptor_set->GetBindingCount(); ++j) {
-                type_index = static_cast<uint32_t>(descriptor_set->GetTypeFromIndex(j));
-                descriptor_count = descriptor_set->GetDescriptorCountFromIndex(j);
-                pool_state->availableDescriptorTypeCount[type_index] += descriptor_count;
-            }
-            FreeDescriptorSet(descriptor_set);
-            pool_state->sets.erase(descriptor_set);
-        }
+    auto pool_state = Get<DESCRIPTOR_POOL_STATE>(descriptorPool);
+    if (pool_state) {
+        pool_state->Free(count, pDescriptorSets);
     }
 }
 
@@ -3926,31 +3881,6 @@
     }
 }
 
-// Decrement allocated sets from the pool and insert new sets into set_map
-void ValidationStateTracker::PerformAllocateDescriptorSets(const VkDescriptorSetAllocateInfo *p_alloc_info,
-                                                           const VkDescriptorSet *descriptor_sets,
-                                                           const cvdescriptorset::AllocateDescriptorSetsData *ds_data) {
-    auto pool_state = descriptorPoolMap[p_alloc_info->descriptorPool].get();
-    // Account for sets and individual descriptors allocated from pool
-    pool_state->availableSets -= p_alloc_info->descriptorSetCount;
-    for (auto it = ds_data->required_descriptors_by_type.begin(); it != ds_data->required_descriptors_by_type.end(); ++it) {
-        pool_state->availableDescriptorTypeCount[it->first] -= ds_data->required_descriptors_by_type.at(it->first);
-    }
-
-    const auto *variable_count_info = LvlFindInChain<VkDescriptorSetVariableDescriptorCountAllocateInfo>(p_alloc_info->pNext);
-    bool variable_count_valid = variable_count_info && variable_count_info->descriptorSetCount == p_alloc_info->descriptorSetCount;
-
-    // Create tracking object for each descriptor set; insert into global map and the pool's set.
-    for (uint32_t i = 0; i < p_alloc_info->descriptorSetCount; i++) {
-        uint32_t variable_count = variable_count_valid ? variable_count_info->pDescriptorCounts[i] : 0;
-
-        auto new_ds = std::make_shared<cvdescriptorset::DescriptorSet>(descriptor_sets[i], pool_state, ds_data->layout_nodes[i],
-                                                                       variable_count, this);
-        pool_state->sets.insert(new_ds.get());
-        setMap[descriptor_sets[i]] = std::move(new_ds);
-    }
-}
-
 void ValidationStateTracker::PostCallRecordCmdDraw(VkCommandBuffer commandBuffer, uint32_t vertexCount, uint32_t instanceCount,
                                                    uint32_t firstVertex, uint32_t firstInstance) {
     CMD_BUFFER_STATE *cb_state = Get<CMD_BUFFER_STATE>(commandBuffer);
diff --git a/layers/state_tracker.h b/layers/state_tracker.h
index 2bb12c1..eafc508 100644
--- a/layers/state_tracker.h
+++ b/layers/state_tracker.h
@@ -1137,15 +1137,11 @@
                                                 VkResult result) override;
 
     // State Utilty functions
-    void DeleteDescriptorSetPools();
-    void FreeDescriptorSet(cvdescriptorset::DescriptorSet* descriptor_set);
     std::vector<std::shared_ptr<const IMAGE_VIEW_STATE>> GetSharedAttachmentViews(const VkRenderPassBeginInfo& rp_begin,
                                                                                   const FRAMEBUFFER_STATE& fb_state) const;
 
     BASE_NODE* GetStateStructPtrFromObject(const VulkanTypedHandle& object_struct);
     VkFormatFeatureFlags GetPotentialFormatFeatures(VkFormat format) const;
-    void PerformAllocateDescriptorSets(const VkDescriptorSetAllocateInfo*, const VkDescriptorSet*,
-                                       const cvdescriptorset::AllocateDescriptorSetsData*);
     void PerformUpdateDescriptorSetsWithTemplateKHR(VkDescriptorSet descriptorSet, const TEMPLATE_STATE* template_state,
                                                     const void* pData);
     void RecordAcquireNextImageState(VkDevice device, VkSwapchainKHR swapchain, uint64_t timeout, VkSemaphore semaphore,