diff --git a/smart-ptr/Makefile b/smart-ptr/Makefile new file mode 100644 index 0000000..cadca4f --- /dev/null +++ b/smart-ptr/Makefile @@ -0,0 +1,18 @@ +CXX = g++ +CXXFLAGS = -std=c++23 -Wall -Wextra -g -I. + +TARGET = main +SRCS = main.cpp + +all: $(TARGET) + +$(TARGET): $(SRCS) + $(CXX) $(CXXFLAGS) -o $(TARGET) $(SRCS) + +run: $(TARGET) + ./$(TARGET) $(ARGS) + +clean: + rm -f $(TARGET) + +.PHONY: all clean diff --git a/smart-ptr/cpp_utility.h b/smart-ptr/cpp_utility.h new file mode 100644 index 0000000..0d51aea --- /dev/null +++ b/smart-ptr/cpp_utility.h @@ -0,0 +1,45 @@ +#ifndef CPP_UTILITY_H +#define CPP_UTILITY_H + +namespace utility { + +// remove_reference +template struct remove_reference +{ + using type = T; +}; +template struct remove_reference +{ + using type = T; +}; +template struct remove_reference +{ + using type = T; +}; + +// move +template +constexpr typename remove_reference::type&& +move(T&& t) noexcept +{ + return static_cast::type&&>(t); +} + +// forward (optional, for perfect forwarding) +template +constexpr T&& +forward(typename remove_reference::type& t) noexcept +{ + return static_cast(t); +} + +template +constexpr T&& +forward(typename remove_reference::type&& t) noexcept +{ + return static_cast(t); +} + +} // namespace utility + +#endif // CPP_UTILITY_H \ No newline at end of file diff --git a/smart-ptr/main.cpp b/smart-ptr/main.cpp new file mode 100644 index 0000000..ab161c1 --- /dev/null +++ b/smart-ptr/main.cpp @@ -0,0 +1,60 @@ +#include +#include +#include "smart_ptr.h" + +struct Test +{ + int x; + Test(int v) : x(v) { std::cout << "Test(" << x << ") constructed\n"; } + ~Test() { std::cout << "Test(" << x << ") destroyed\n"; } +}; + +int +main() +{ + std::cout << "--- Testing SharedPtr ---\\n"; + { + SharedPtr sp1(new Test(10)); + assert(sp1->x == 10); + assert(sp1.use_count() == 1); + + { + SharedPtr sp2 = sp1; + assert(sp2->x == 10); + assert(sp1.use_count() == 2); + assert(sp2.use_count() == 2); + } + assert(sp1.use_count() == 1); + } + std::cout << "SharedPtr test passed.\n"; + + std::cout << "--- Testing MakeShared ---\\n"; + { + SharedPtr sp = MakeShared(20); + assert(sp->x == 20); + assert(sp.use_count() == 1); + } + std::cout << "MakeShared test passed.\n"; + + std::cout << "--- Testing WeakPtr ---\\n"; + { + SharedPtr sp = MakeShared(30); + WeakPtr wp = sp; + assert(!wp.expired()); + + SharedPtr sp2 = wp.lock(); + assert(sp2); + assert(sp2->x == 30); + assert(sp.use_count() == 2); + + sp2 = nullptr; // release one reference + assert(sp.use_count() == 1); + + sp = nullptr; // release last reference + assert(wp.expired()); + assert(wp.lock().get() == nullptr); + } + std::cout << "WeakPtr test passed.\n"; + + return 0; +} diff --git a/smart-ptr/ref_block_base.h b/smart-ptr/ref_block_base.h new file mode 100644 index 0000000..e754f41 --- /dev/null +++ b/smart-ptr/ref_block_base.h @@ -0,0 +1,128 @@ +#ifndef REF_BLOCK_BASE_H +#define REF_BLOCK_BASE_H + +#include +#include +#include +#include + +#include "cpp_utility.h" + +struct RefBlockBase +{ + // atomic operation + std::atomic m_shared_count{1}; // when init, create the shared_ptr and own it + std::atomic m_weak_count{1}; // when init, m_shared_count as one "weak reference" + + virtual ~RefBlockBase() = default; + + // type erasure + virtual void + dispose_resource() = 0; + virtual void + destroy_self() = 0; + virtual void* + get_resource_ptr() + { + return nullptr; + } + + // ---- thread safe counter operation ---- + void + increment_shared() noexcept + { + m_shared_count.fetch_add(1, std::memory_order_relaxed); + } + + void + increment_weak() noexcept + { + m_weak_count.fetch_add(1, std::memory_order_relaxed); + } + + void + decrement_shared() noexcept + { + // fetch_sub return the value before minus + // use acq_rel (Acquire-Release) ensure memory safe + if (m_shared_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + dispose_resource(); + decrement_weak(); + } + } + + void + decrement_weak() noexcept + { + if (m_weak_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + // weak counter oges 0, destroy control block + destroy_self(); + } + } + + bool + try_increment_shared() noexcept + { + size_t count = m_shared_count.load(std::memory_order_relaxed); + + while (count != 0) { + // try to replace count with count + 1 + if (m_shared_count.compare_exchange_weak(count, count + 1, std::memory_order_acq_rel)) { + return true; // success + } + } + return false; + } +}; + +// for 'new' +// Y is the actual type, D is the del type +template struct RefBlockImpl : public RefBlockBase +{ + Y* m_resource; + D m_deleter; + + RefBlockImpl(Y* res, D del) : m_resource(res), m_deleter(utility::move(del)) {} + + void + dispose_resource() override + { + // call the deleter + m_deleter(m_resource); + } + + void + destroy_self() override + { + // destroy self + delete this; + } +}; + +template struct RefBlockMakeShared : public RefBlockBase +{ + // T's data will followed directly after this struct + // use an aligned char array for padding + alignas(T) char m_storage[sizeof(T)]; + + void* + get_resource_ptr() override + { + return reinterpret_cast(m_storage); + } + + void + dispose_resource() override + { + // call the deconstruct but not release the memory + reinterpret_cast(m_storage)->~T(); + } + + void + destroy_self() override + { + delete this; + } +}; + +#endif diff --git a/smart-ptr/smart_ptr.h b/smart-ptr/smart_ptr.h new file mode 100644 index 0000000..df25c14 --- /dev/null +++ b/smart-ptr/smart_ptr.h @@ -0,0 +1,240 @@ +#ifndef SMART_PTR_H +#define SMART_PTR_H + +#include +#include "cpp_utility.h" +#include "ref_block_base.h" + +// forwared declaration +template class SharedPtr; + +template class WeakPtr; + +template +SharedPtr +MakeShared(Args&&... args); + +template class SharedPtr +{ + // allow weaktptr to access private + friend class WeakPtr; + template + friend SharedPtr + MakeShared(Args&&... args); + + private: + T* m_ptr = nullptr; + + RefBlockBase* m_block = nullptr; + + // private construck for weakptr lock and MakeShared + SharedPtr(T* ptr, RefBlockBase* block) noexcept : m_ptr(ptr), m_block(block) {} + + public: + // --- construct func --- + + // default + SharedPtr() noexcept = default; + + // accept rao ptr + explicit SharedPtr(T* ptr) : SharedPtr(ptr, std::default_delete()) {} + + // construct from nullptr + SharedPtr(std::nullptr_t) noexcept : SharedPtr() {} + + // core construct with deleter + template < + typename Y, + typename D, + typename = + typename std::enable_if::type, RefBlockBase*>::value>::type> + SharedPtr(Y* ptr, D&& deleter) : m_ptr(ptr) + { + using DeleterType = typename std::decay::type; + DeleterType deleter_copy = utility::forward(deleter); + try { + // try to allocate for control block + m_block = new RefBlockImpl(ptr, deleter_copy); + } catch (...) { + // if new refblockimpl fails, delete ptr and through error + deleter_copy(ptr); + throw; + } + } + + // move constructor + SharedPtr(SharedPtr&& other) noexcept : m_ptr(other.m_ptr), m_block(other.m_block) + { + other.m_ptr = nullptr; + other.m_block = nullptr; + } + + // --- deconstructor --- + ~SharedPtr() noexcept + { + if (m_block) { + m_block->decrement_shared(); + } + } + + // --- copy control --- + SharedPtr(const SharedPtr& other) : m_ptr(other.m_ptr), m_block(other.m_block) + { + if (m_block) { + m_block->increment_shared(); + } + } + + SharedPtr& + operator=(const SharedPtr& other) + { + if (this != &other) { + // 1. release old resource + if (m_block) { + m_block->decrement_shared(); + } + + // 2. copy new resource + m_ptr = other.m_ptr; + m_block = other.m_block; + + if (m_block) { + m_block->increment_shared(); + } + } + return *this; + } + + // move assignment operator + SharedPtr& + operator=(SharedPtr&& other) noexcept + { + if (this != &other) { + // 1. Release current resources + if (m_block) { + m_block->decrement_shared(); + } + + // 2. Take ownership from other + m_ptr = other.m_ptr; + m_block = other.m_block; + + // 3. Null other's resources + other.m_ptr = nullptr; + other.m_block = nullptr; + } + return *this; + } + + SharedPtr& + operator=(std::nullptr_t) + { + if (m_block) { + m_block->decrement_shared(); + m_block = nullptr; + m_ptr = nullptr; + } + return *this; + } + + T* + get() const noexcept + { + return m_ptr; + } + + T& + operator*() const noexcept + { + return *m_ptr; + } + T* + operator->() const noexcept + { + return m_ptr; + } + + size_t + use_count() const noexcept + { + return m_block ? m_block->m_shared_count.load() : 0; + } + + explicit + operator bool() const noexcept + { + return m_ptr != nullptr; + } +}; + +template class WeakPtr +{ + private: + // weak ptr can not safely holding T*, cause T can be destroyed any time + RefBlockBase* m_block = nullptr; + + public: + // allow shared ptr to access + friend class SharedPtr; + + WeakPtr() noexcept = default; + + WeakPtr(const SharedPtr& shared) noexcept : m_block(shared.m_block) + { + if (m_block) { + m_block->increment_weak(); + } + } + + ~WeakPtr() noexcept + { + if (m_block) { + m_block->decrement_weak(); + } + } + + bool + expired() const noexcept + { + return !m_block || m_block->m_shared_count.load(std::memory_order_acquire) == 0; + } + + SharedPtr + lock() const noexcept + { + if (expired()) { + return SharedPtr(); + } + if (m_block->try_increment_shared()) { + return SharedPtr(reinterpret_cast(m_block->get_resource_ptr()), m_block); + } else { + return SharedPtr(); + } + } +}; + +template +SharedPtr +MakeShared(Args&&... args) +{ + // allocate one time for space for T and RefBlock + // allocate refblockmakeshared, which includes already size for T + RefBlockMakeShared* block = new RefBlockMakeShared(); + + void* void_ptr = block->get_resource_ptr(); + T* ptr = static_cast(void_ptr); + + // placement new to construct T on this address + try { + ::new (ptr) T(utility::forward(args)...); + + } catch (...) { + delete block; + throw; + } + + // use private constructor to return sharedptr + return SharedPtr(ptr, block); +} + +#endif