diff --git a/src/rpp/rpp/operators/concat.hpp b/src/rpp/rpp/operators/concat.hpp index 60e20d7dd..62fceac9f 100644 --- a/src/rpp/rpp/operators/concat.hpp +++ b/src/rpp/rpp/operators/concat.hpp @@ -12,10 +12,10 @@ #include -#include #include #include +#include #include #include @@ -34,7 +34,8 @@ namespace rpp::operators::details }; template - class concat_disposable final : public rpp::refcount_disposable + class concat_disposable final : public rpp::composite_disposable + , public rpp::details::enable_wrapper_from_this> { public: concat_disposable(TObserver&& observer) @@ -47,7 +48,7 @@ namespace rpp::operators::details std::atomic& stage() { return m_stage; } - void drain(rpp::composite_disposable_wrapper refcounted) + void drain() { while (!is_disposed()) { @@ -55,36 +56,43 @@ namespace rpp::operators::details if (!observable) { stage().store(ConcatStage::None, std::memory_order::relaxed); - refcounted.dispose(); - if (is_disposed()) + if (get_base_child_disposable().is_disposed()) get_observer()->on_completed(); return; } - if (handle_observable_impl(observable.value(), refcounted)) + if (handle_observable_impl(observable.value())) return; } } - void handle_observable(const rpp::constraint::decayed_same_as auto& observable, rpp::composite_disposable_wrapper refcounted) + void handle_observable(const rpp::constraint::decayed_same_as auto& observable) { - if (handle_observable_impl(observable, refcounted)) + if (handle_observable_impl(observable)) return; - drain(refcounted); + drain(); } + rpp::composite_disposable& get_base_child_disposable() { return m_child_disposables[0]; } + rpp::composite_disposable& get_inner_child_disposable() { return m_child_disposables[1]; } + private: - bool handle_observable_impl(const rpp::constraint::decayed_same_as auto& observable, rpp::composite_disposable_wrapper refcounted) + bool handle_observable_impl(const rpp::constraint::decayed_same_as auto& observable) { stage().store(ConcatStage::Draining, std::memory_order::relaxed); - observable.subscribe(concat_inner_observer_strategy{disposable_wrapper_impl{wrapper_from_this()}.lock(), std::move(refcounted)}); + observable.subscribe(concat_inner_observer_strategy{disposable_wrapper_impl{this->wrapper_from_this()}.lock()}); ConcatStage current = ConcatStage::Draining; return stage().compare_exchange_strong(current, ConcatStage::Processing, std::memory_order::seq_cst); } - private: + void composite_dispose_impl(interface_disposable::Mode) noexcept override + { + for (auto& d : m_child_disposables) + d.dispose(); + } + std::optional get_observable() { auto queue = get_queue(); @@ -99,72 +107,58 @@ namespace rpp::operators::details rpp::utils::value_with_mutex m_observer; rpp::utils::value_with_mutex> m_queue; std::atomic m_stage{}; - }; - - template - struct concat_observer_strategy_base - { - concat_observer_strategy_base(std::shared_ptr> disposable, rpp::composite_disposable_wrapper refcounted) - : disposable{std::move(disposable)} - , refcounted{std::move(refcounted)} - { - } - - concat_observer_strategy_base(std::shared_ptr> disposable) - : concat_observer_strategy_base{disposable, disposable->add_ref()} - { - } - - std::shared_ptr> disposable; - rpp::composite_disposable_wrapper refcounted; - - void on_error(const std::exception_ptr& err) const - { - disposable->get_observer()->on_error(err); - } - - void set_upstream(const disposable_wrapper& d) const { refcounted.add(d); } - bool is_disposed() const { return refcounted.is_disposed(); } + std::array m_child_disposables{}; }; template - struct concat_inner_observer_strategy : public concat_observer_strategy_base + struct concat_inner_observer_strategy { static constexpr auto preferred_disposables_mode = rpp::details::observers::disposables_mode::None; - using base = concat_observer_strategy_base; - using base::concat_observer_strategy_base; + std::shared_ptr> disposable{}; + mutable bool locally_disposed{}; template void on_next(T&& v) const { - base::disposable->get_observer()->on_next(std::forward(v)); + disposable->get_observer()->on_next(std::forward(v)); + } + + void on_error(const std::exception_ptr& err) const + { + locally_disposed = true; + disposable->get_observer()->on_error(err); } void on_completed() const { - base::refcounted.clear(); + locally_disposed = true; + disposable->get_inner_child_disposable().clear(); ConcatStage current{ConcatStage::Draining}; - if (base::disposable->stage().compare_exchange_strong(current, ConcatStage::CompletedWhileDraining, std::memory_order::seq_cst)) + if (disposable->stage().compare_exchange_strong(current, ConcatStage::CompletedWhileDraining, std::memory_order::seq_cst)) return; assert(current == ConcatStage::Processing); - base::disposable->drain(base::refcounted); + disposable->drain(); } + + void set_upstream(const disposable_wrapper& d) const { disposable->get_inner_child_disposable().add(d); } + + bool is_disposed() const { return locally_disposed || disposable->get_inner_child_disposable().is_disposed(); } }; template - struct concat_observer_strategy : public concat_observer_strategy_base + struct concat_observer_strategy { - using base = concat_observer_strategy_base; - static constexpr auto preferred_disposables_mode = rpp::details::observers::disposables_mode::None; + std::shared_ptr> disposable; + concat_observer_strategy(TObserver&& observer) - : base{init_state(std::move(observer))} + : disposable{init_state(std::move(observer))} { } @@ -172,19 +166,27 @@ namespace rpp::operators::details void on_next(T&& v) const { ConcatStage current = ConcatStage::None; - if (base::disposable->stage().compare_exchange_strong(current, ConcatStage::Draining, std::memory_order::seq_cst)) - base::disposable->handle_observable(std::forward(v), base::disposable->add_ref()); + if (disposable->stage().compare_exchange_strong(current, ConcatStage::Draining, std::memory_order::seq_cst)) + disposable->handle_observable(std::forward(v)); else - base::disposable->get_queue()->push(std::forward(v)); + disposable->get_queue()->push(std::forward(v)); + } + + void on_error(const std::exception_ptr& err) const + { + disposable->get_observer()->on_error(err); } void on_completed() const { - base::refcounted.dispose(); - if (base::disposable->is_disposed()) - base::disposable->get_observer()->on_completed(); + disposable->get_base_child_disposable().dispose(); + if (disposable->stage() == ConcatStage::None) + disposable->get_observer()->on_completed(); } + void set_upstream(const disposable_wrapper& d) const { disposable->get_base_child_disposable().add(d); } + + bool is_disposed() const { return disposable->get_base_child_disposable().is_disposed(); } private: static std::shared_ptr> init_state(TObserver&& observer) diff --git a/src/tests/rpp/test_concat.cpp b/src/tests/rpp/test_concat.cpp index c8a20cc92..59196f72e 100644 --- a/src/tests/rpp/test_concat.cpp +++ b/src/tests/rpp/test_concat.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -226,6 +227,35 @@ TEST_CASE_TEMPLATE("concat", TestType, rpp::memory_model::use_stack, rpp::memory test([](auto&&... vals) { return rpp::source::just(std::forward(vals).as_dynamic()...) | rpp::ops::concat(); }); + SUBCASE("concat completes right") + { + rpp::subjects::publish_subject> subj{}; + + subj.get_observable() | rpp::ops::concat() | rpp::ops::subscribe(mock); + SUBCASE("on_completed from base") + { + REQUIRE_CALL(*mock, on_completed()).IN_SEQUENCE(s); + subj.get_observer().on_completed(); + } + + SUBCASE("on_completed from inner + then from base") + { + subj.get_observer().on_next(rpp::source::empty()); + + REQUIRE_CALL(*mock, on_completed()).IN_SEQUENCE(s); + subj.get_observer().on_completed(); + } + + SUBCASE("on_completed from base + then from inner") + { + rpp::subjects::publish_subject inner{}; + subj.get_observer().on_next(inner.get_observable()); + subj.get_observer().on_completed(); + + REQUIRE_CALL(*mock, on_completed()).IN_SEQUENCE(s); + inner.get_observer().on_completed(); + } + } } }