diff --git a/include/nvexec/stream/continues_on.cuh b/include/nvexec/stream/continues_on.cuh index a7ee8f260..eb254959d 100644 --- a/include/nvexec/stream/continues_on.cuh +++ b/include/nvexec/stream/continues_on.cuh @@ -172,15 +172,13 @@ namespace nv::execution::_strm template struct source_sender : stream_sender_base { - using schedule_from_sender_t = __result_of; - explicit source_sender(Sender sndr) - : sndr_(schedule_from(static_cast(sndr))) + : sndr_(static_cast(sndr)) {} template <__decay_copyable Self, STDEXEC::receiver Receiver> STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr) - -> connect_result_t<__copy_cvref_t, Receiver> + -> connect_result_t<__copy_cvref_t, Receiver> { return STDEXEC::connect(static_cast(self).sndr_, static_cast(rcvr)); } @@ -201,7 +199,7 @@ namespace nv::execution::_strm } private: - __result_of sndr_; + Sender sndr_; }; template diff --git a/test/nvexec/CMakeLists.txt b/test/nvexec/CMakeLists.txt index b304800f1..3eca1df94 100644 --- a/test/nvexec/CMakeLists.txt +++ b/test/nvexec/CMakeLists.txt @@ -15,6 +15,7 @@ #============================================================================= set(nvexec_test_sources + continues_on.cpp bulk.cpp ensure_started.cpp start_detached.cpp diff --git a/test/nvexec/continues_on.cpp b/test/nvexec/continues_on.cpp new file mode 100644 index 000000000..852f7d59f --- /dev/null +++ b/test/nvexec/continues_on.cpp @@ -0,0 +1,37 @@ +#include +#include + +#include "nvexec/stream_context.cuh" + +namespace +{ + TEST_CASE("continues on after just", "[cuda][stream][adaptors][continues_on]") + { + nvexec::stream_context ctx; + + auto sndr = STDEXEC::just() | STDEXEC::continues_on(ctx.get_scheduler()); + + STDEXEC::sync_wait(std::move(sndr)); + } + + TEST_CASE("continues on after schedule", "[cuda][stream][adaptors][continues_on]") + { + nvexec::stream_context ctx; + + auto sndr = STDEXEC::schedule(ctx.get_scheduler()) + | STDEXEC::continues_on(ctx.get_scheduler()); + + STDEXEC::sync_wait(std::move(sndr)); + } + + TEST_CASE("continues on twice in a row", "[cuda][stream][adaptors][continues_on]") + { + nvexec::stream_context ctx; + + auto sndr = STDEXEC::just() + | STDEXEC::continues_on(ctx.get_scheduler()) + | STDEXEC::continues_on(ctx.get_scheduler()); + + STDEXEC::sync_wait(std::move(sndr)); + } +} // namespace