blob: 30de0e62d83dc4557d9c0ec01735f04d37cb93db [file] [log] [blame]
#pragma once
//==============================================================================================
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//==============================================================================================
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
// Minimal reference-counted owning pointer usable from host and device code.
//
// All SharedPtr instances that share ownership of an object point at the same
// heap-allocated counter. The object and the counter are deleted when the
// last owner lets go of them.
//
// The counter is a plain size_t updated with ++/--; nothing here synchronizes
// concurrent updates to one counter from several threads.
// NOTE(review): the object and the counter are freed with `delete` by
// whichever side (host or device) drops the last reference; presumably that
// has to be the side that allocated them -- confirm against the CUDA rules
// for host/device heaps.
template <typename T> class SharedPtr {
public:
  // Constructor: takes ownership of `ptr`. A null `ptr` produces an empty
  // SharedPtr that has no counter (use_count() == 0).
  __host__ __device__ explicit SharedPtr(T *ptr = nullptr)
      : ptr_(ptr), ref_count_(ptr ? new size_t(1) : nullptr) {}

  // Destructor: gives up this instance's share of the object.
  __host__ __device__ ~SharedPtr() { release(); }

  // Copy constructor: shares ownership with `other`.
  __host__ __device__ SharedPtr(const SharedPtr &other)
      : ptr_(other.ptr_), ref_count_(other.ref_count_) {
    if (ref_count_) {
      ++(*ref_count_);
    }
  }

  // Templated constructor for conversion from SharedPtr<U> to SharedPtr<T>.
  // Only compiles when U* converts implicitly to T*. It reads the members of
  // a different specialization, which works because they are public.
  template <typename U>
  __host__ __device__ SharedPtr(const SharedPtr<U> &other)
      : ptr_(other.ptr_), ref_count_(other.ref_count_) {
    static_assert(std::is_convertible<U *, T *>::value,
                  "SharedPtr<U> cannot be converted to SharedPtr<T>");
    if (ref_count_) {
      ++(*ref_count_);
    }
  }

  // Copy assignment operator: drops the current object (if any) and shares
  // ownership with `other`. Returns *this.
  __host__ __device__ SharedPtr &operator=(const SharedPtr &other) {
    if (this == &other) {
      return *this;
    }
    // Take the new reference BEFORE dropping the old one. `other` may live
    // inside the object we currently own (e.g. `head = head->next`), in which
    // case releasing first would destroy `other` before it is read.
    T *const incoming_ptr = other.ptr_;
    size_t *const incoming_count = other.ref_count_;
    if (incoming_count) {
      ++(*incoming_count);
    }
    release(); // `other` must not be touched after this point.
    ptr_ = incoming_ptr;
    ref_count_ = incoming_count;
    return *this;
  }

  // Accessors
  // Returns the managed pointer, or nullptr when empty.
  __host__ __device__ T *get() const { return ptr_; }
  // Returns the number of owners sharing the object, or 0 when empty.
  __host__ __device__ size_t use_count() const {
    return ref_count_ ? *ref_count_ : 0;
  }
  // Dereference operators; no null check is performed.
  __host__ __device__ T &operator*() const { return *ptr_; }
  __host__ __device__ T *operator->() const { return ptr_; }

public:
  T *ptr_;            // Pointer to the managed object
  size_t *ref_count_; // Reference count

  // Gives up this instance's share of the object and leaves the instance
  // empty. Deletes the object and the counter when this was the last owner.
  // Safe to call more than once: later calls are no-ops.
  __host__ __device__ void release() {
    // Detach first, so the members never dangle and a destructor that runs
    // during `delete` cannot observe or re-release a half-released pointer.
    T *const old_ptr = ptr_;
    size_t *const old_count = ref_count_;
    ptr_ = nullptr;
    ref_count_ = nullptr;
    if (old_count && --(*old_count) == 0) {
      delete old_ptr;
      delete old_count;
    }
  }
};
// Allocates a T with `new`, perfectly forwarding `args` to its constructor,
// and returns a SharedPtr<T> that owns the new object.
template <typename T, typename... Args>
__host__ __device__ SharedPtr<T> makeShared(Args &&...args) {
  T *const object = new T(std::forward<Args>(args)...);
  return SharedPtr<T>(object);
}