diff --git a/bee/async/async.cpp b/bee/async/async.cpp new file mode 100644 index 00000000..d730cb43 --- /dev/null +++ b/bee/async/async.cpp @@ -0,0 +1,28 @@ +#include + +#if defined(__linux__) +# include +# if !defined(BEE_ASYNC_BACKEND_EPOLL) +# include +# endif +#endif + +namespace bee::async { + + std::unique_ptr create() { +#if defined(__linux__) +# if defined(BEE_ASYNC_BACKEND_EPOLL) + return std::make_unique(); +# else + auto uring = std::make_unique(); + if (uring->valid()) { + return uring; + } + return std::make_unique(); +# endif +#else + return std::make_unique(); +#endif + } + +} // namespace bee::async diff --git a/bee/async/async.h b/bee/async/async.h new file mode 100644 index 00000000..1deb7e39 --- /dev/null +++ b/bee/async/async.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if defined(_WIN32) +# include +#elif defined(__APPLE__) +# if defined(BEE_ASYNC_BACKEND_KQUEUE) +# include +# else +# include +# endif +#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) +# include +#endif + +namespace bee::async { + +#if defined(__linux__) + + class async { + public: + virtual ~async() = default; + virtual bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) = 0; + virtual bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id) = 0; + virtual bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) = 0; + virtual bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id) = 0; + virtual bool submit_accept(net::fd_t listen_fd, uint64_t request_id) = 0; + virtual bool submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) = 0; + virtual bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) = 0; + virtual bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t 
request_id) = 0; + virtual bool submit_poll(net::fd_t fd, uint64_t request_id) = 0; + virtual int poll(const span& completions) = 0; + virtual int wait(const span& completions, int timeout) = 0; + virtual void stop() = 0; + virtual void cancel(net::fd_t fd) = 0; + }; + +#endif + + // Factory function: create the async backend for the current platform. + std::unique_ptr create(); + +} // namespace bee::async diff --git a/bee/async/async_bsd.cpp b/bee/async/async_bsd.cpp new file mode 100644 index 00000000..dbe33e01 --- /dev/null +++ b/bee/async/async_bsd.cpp @@ -0,0 +1,490 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace bee::async { + + async::async() + : m_kqfd(-1) { + m_kqfd = kqueue(); + } + + async::~async() { + stop(); + } + + bool async::kqueue_register(net::fd_t fd, int filter, pending_op* op) { + struct kevent ev; + EV_SET(&ev, fd, filter, EV_ADD | EV_ONESHOT, 0, 0, op); + return kevent(m_kqfd, &ev, 1, nullptr, 0, nullptr) == 0; + } + + bool async::submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::read; + op->r.buffer = buffer; + op->r.len = len; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_READ, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + bool async::submit_readv(net::fd_t fd, span bufs, uint64_t request_id) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::readv; + op->wv = dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) op->wv[i] = bufs[i]; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_READ, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + bool async::submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) { + auto* op = new pending_op(); + 
op->request_id = request_id; + op->fd = fd; + op->type = pending_op::write; + op->w.buffer = buffer; + op->w.len = len; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_WRITE, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + bool async::submit_writev(net::fd_t fd, span bufs, uint64_t request_id) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::writev; + op->wv = dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) op->wv[i] = bufs[i]; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_WRITE, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + bool async::submit_accept(net::fd_t listen_fd, uint64_t request_id) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = listen_fd; + op->type = pending_op::accept; + m_pending_ops.insert(op); + if (!kqueue_register(listen_fd, EVFILT_READ, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + bool async::submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) { + auto status = net::socket::connect(fd, ep); + if (status == net::socket::status::success) { + io_completion c; + c.request_id = request_id; + c.op = async_op::connect; + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + m_sync_completions.push_back(c); + return true; + } + if (status == net::socket::status::wait) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::connect; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_WRITE, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + return false; + } + + bool async::submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) { + io_completion c; + c.request_id = request_id; + c.op = 
async_op::file_read; + ssize_t n = pread(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + m_sync_completions.push_back(c); + return true; + } + + bool async::submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) { + io_completion c; + c.request_id = request_id; + c.op = async_op::file_write; + ssize_t n = pwrite(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + m_sync_completions.push_back(c); + return true; + } + + bool async::submit_poll(net::fd_t fd, uint64_t request_id) { + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::fd_poll; + m_pending_ops.insert(op); + if (!kqueue_register(fd, EVFILT_READ, op)) { + m_pending_ops.erase(op); + delete op; + return false; + } + return true; + } + + // 重新注册 EV_ONESHOT 事件(用于 spurious wakeup 后保留 op 继续监听)。 + static bool kqueue_rearm(int kqfd, net::fd_t fd, int filter, async::pending_op* op) { + struct kevent ev; + EV_SET(&ev, fd, filter, EV_ADD | EV_ONESHOT, 0, 0, op); + return kevent(kqfd, &ev, 1, nullptr, 0, nullptr) == 0; + } + + // 处理单个 kevent 事件。返回 true 表示产出了有效的 completion,false 表示 + // 遇到了 spurious wakeup(EAGAIN),op 已被重新注册到 kqueue,不产出 completion。 + static bool handle_event(int kqfd, struct kevent& ev, std::unordered_set& pending_ops, io_completion& c) { + auto* op = static_cast(ev.udata); + c.request_id = op->request_id; + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = 0; + + net::fd_t fd = op->fd; + + // Check for error flag first (applies to all filter types). 
+ if (ev.flags & EV_ERROR) { + c.error_code = static_cast(ev.data); + switch (op->type) { + case async::pending_op::read: + c.op = async_op::read; + break; + case async::pending_op::readv: + c.op = async_op::readv; + break; + case async::pending_op::write: + c.op = async_op::write; + break; + case async::pending_op::writev: + c.op = async_op::writev; + break; + case async::pending_op::accept: + c.op = async_op::accept; + break; + case async::pending_op::connect: + c.op = async_op::connect; + break; + case async::pending_op::fd_poll: + c.op = async_op::fd_poll; + break; + } + pending_ops.erase(op); + delete op; + return true; + } + + switch (op->type) { + case async::pending_op::read: { + c.op = async_op::read; + // EOF on EVFILT_READ: data == 0 and EV_EOF set. + if ((ev.flags & EV_EOF) && ev.data == 0) { + c.status = async_status::close; + break; + } + int rc = 0; + auto rs = net::socket::recv(fd, rc, static_cast(op->r.buffer), static_cast(op->r.len)); + switch (rs) { + case net::socket::recv_status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + c.status = async_status::close; + break; + case net::socket::recv_status::failed: + c.status = async_status::error; + c.error_code = errno; + break; + case net::socket::recv_status::wait: + // Spurious wakeup:kevent 触发但实际无数据,重新注册等待下次事件。 + if (kqueue_rearm(kqfd, fd, EVFILT_READ, op)) { + return false; + } + c.status = async_status::error; + c.error_code = EAGAIN; + break; + } + break; + } + case async::pending_op::readv: { + c.op = async_op::readv; + if ((ev.flags & EV_EOF) && ev.data == 0) { + c.status = async_status::close; + break; + } + int rc = 0; + auto rs = net::socket::recvv(fd, rc, span(op->wv.data(), op->wv.size())); + switch (rs) { + case net::socket::recv_status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + c.status = 
async_status::close; + break; + case net::socket::recv_status::failed: + c.status = async_status::error; + c.error_code = errno; + break; + case net::socket::recv_status::wait: + if (kqueue_rearm(kqfd, fd, EVFILT_READ, op)) { + return false; + } + c.status = async_status::error; + c.error_code = EAGAIN; + break; + } + break; + } + case async::pending_op::write: { + c.op = async_op::write; + int rc = 0; + auto ss = net::socket::send(fd, rc, static_cast(op->w.buffer), static_cast(op->w.len)); + switch (ss) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + if (kqueue_rearm(kqfd, fd, EVFILT_WRITE, op)) { + return false; + } + c.status = async_status::error; + c.error_code = EAGAIN; + break; + case net::socket::status::failed: + c.status = async_status::error; + c.error_code = errno; + break; + } + break; + } + case async::pending_op::writev: { + c.op = async_op::writev; + int rc = 0; + auto ss = net::socket::sendv(fd, rc, span(op->wv.data(), op->wv.size())); + switch (ss) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + if (kqueue_rearm(kqfd, fd, EVFILT_WRITE, op)) { + return false; + } + c.status = async_status::error; + c.error_code = EAGAIN; + break; + case net::socket::status::failed: + c.status = async_status::error; + c.error_code = errno; + break; + } + break; + } + case async::pending_op::accept: { + c.op = async_op::accept; + net::fd_t newfd = net::retired_fd; + auto as = net::socket::accept(fd, newfd); + switch (as) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(newfd); + break; + case net::socket::status::wait: + if (kqueue_rearm(kqfd, fd, EVFILT_READ, op)) { + return false; + } + c.status = async_status::error; + c.error_code = EAGAIN; + break; + case 
net::socket::status::failed: + c.status = async_status::error; + c.error_code = errno; + break; + } + break; + } + case async::pending_op::connect: { + c.op = async_op::connect; + int err = 0; + if (net::socket::errcode(fd, err) && err == 0) { + c.status = async_status::success; + } else { + c.status = async_status::error; + c.error_code = (err != 0) ? err : errno; + } + break; + } + case async::pending_op::fd_poll: { + // fd_poll: 只通知 fd 可读,不消费任何数据 + c.op = async_op::fd_poll; + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + break; + } + } + + pending_ops.erase(op); + delete op; + return true; + } + + static int drain_kqueue(int kqfd, const span& completions, int timeout_ms, std::unordered_set& pending_ops) { + struct kevent events[async::kMaxEvents]; + int nev; + if (timeout_ms < 0) { + nev = kevent(kqfd, nullptr, 0, events, async::kMaxEvents, nullptr); + } else { + struct timespec ts; + ts.tv_sec = timeout_ms / 1000; + ts.tv_nsec = (timeout_ms % 1000) * 1000000L; + nev = kevent(kqfd, nullptr, 0, events, async::kMaxEvents, &ts); + } + if (nev <= 0) { + return 0; + } + int count = 0; + for (int i = 0; i < nev && count < static_cast(completions.size()); ++i) { + io_completion c; + if (handle_event(kqfd, events[i], pending_ops, c)) { + completions[count++] = c; + } + } + return count; + } + + int async::poll(const span& completions) { + int count = 0; + + while (!m_sync_completions.empty() && count < static_cast(completions.size())) { + completions[count++] = m_sync_completions.front(); + m_sync_completions.pop_front(); + } + if (count >= static_cast(completions.size())) { + return count; + } + + count += drain_kqueue(m_kqfd, span(completions.data() + count, completions.size() - count), 0, m_pending_ops); + return count; + } + + int async::wait(const span& completions, int timeout) { + int count = 0; + + while (!m_sync_completions.empty() && count < static_cast(completions.size())) { + completions[count++] = 
m_sync_completions.front(); + m_sync_completions.pop_front(); + } + if (count > 0 || completions.size() == 0) { + return count; + } + + count += drain_kqueue(m_kqfd, completions, timeout, m_pending_ops); + return count; + } + + void async::stop() { + if (m_kqfd >= 0) { + close(m_kqfd); + m_kqfd = -1; + } + for (auto* op : m_pending_ops) { + delete op; + } + m_pending_ops.clear(); + } + + static int op_filter(async::pending_op::type_t type) { + switch (type) { + case async::pending_op::read: + case async::pending_op::readv: + case async::pending_op::accept: + case async::pending_op::fd_poll: + return EVFILT_READ; + case async::pending_op::write: + case async::pending_op::writev: + case async::pending_op::connect: + return EVFILT_WRITE; + } + return EVFILT_READ; + } + + void async::cancel(net::fd_t fd) { + for (auto it = m_pending_ops.begin(); it != m_pending_ops.end(); ) { + if ((*it)->fd == fd) { + // 在释放 pending_op 之前,先从 kqueue 中取消注册事件, + // 避免后续 kevent 返回已释放的 udata 指针导致 use-after-free。 + struct kevent ev; + EV_SET(&ev, fd, op_filter((*it)->type), EV_DELETE, 0, 0, nullptr); + kevent(m_kqfd, &ev, 1, nullptr, 0, nullptr); + delete *it; + it = m_pending_ops.erase(it); + } else { + ++it; + } + } + } + +} // namespace bee::async diff --git a/bee/async/async_bsd.h b/bee/async/async_bsd.h new file mode 100644 index 00000000..93a7e0e5 --- /dev/null +++ b/bee/async/async_bsd.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace bee::net { + struct endpoint; +} + +namespace bee::async { + + class async { + public: + async(); + ~async(); + + bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id); + bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id); + bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id); + bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id); + bool submit_accept(net::fd_t listen_fd, 
uint64_t request_id); + bool submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id); + bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id); + bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id); + bool submit_poll(net::fd_t fd, uint64_t request_id); + int poll(const span& completions); + int wait(const span& completions, int timeout); + void stop(); + void cancel(net::fd_t fd); + + struct pending_op { + uint64_t request_id = 0; + net::fd_t fd = net::retired_fd; + enum type_t : uint8_t { + read, + readv, + write, + writev, + accept, + connect, + fd_poll, + } type = read; + union { + struct { + void* buffer; + size_t len; + } r; + struct { + const void* buffer; + size_t len; + } w; + }; + dynarray wv; // used when type == writev or readv + }; + + static constexpr int kMaxEvents = 64; + + private: + int m_kqfd; + std::deque m_sync_completions; + std::unordered_set m_pending_ops; // all registered but not yet fired ops + + bool kqueue_register(net::fd_t fd, int filter, pending_op* op); + }; + +} // namespace bee::async diff --git a/bee/async/async_epoll_linux.cpp b/bee/async/async_epoll_linux.cpp new file mode 100644 index 00000000..0796994d --- /dev/null +++ b/bee/async/async_epoll_linux.cpp @@ -0,0 +1,568 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace bee::async { + + async_epoll::async_epoll() + : m_epfd(-1) { + m_epfd = epoll_create1(EPOLL_CLOEXEC); + } + + async_epoll::~async_epoll() { + stop(); + } + + // Build combined events mask from fd_state and call epoll_ctl ADD or MOD. 
+ bool async_epoll::fd_arm(net::fd_t fd, fd_state& state) { + uint32_t events = 0; + if (state.read_op) events |= EPOLLIN; + if (state.write_op) events |= EPOLLOUT; + + if (events == state.events) { + return true; + } + + struct epoll_event ev; + memset(&ev, 0, sizeof(ev)); + ev.events = events; + ev.data.fd = fd; + + int op = (state.events == 0) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD; + if (epoll_ctl(m_epfd, op, fd, &ev) != 0) { + return false; + } + state.events = events; + return true; + } + + // Remove one direction from fd_state; update or DEL epoll registration. + void async_epoll::fd_disarm(net::fd_t fd, fd_state& state, bool is_write) { + if (is_write) { + state.write_op = nullptr; + } else { + state.read_op = nullptr; + } + + uint32_t new_events = 0; + if (state.read_op) new_events |= EPOLLIN; + if (state.write_op) new_events |= EPOLLOUT; + + if (new_events == 0) { + epoll_ctl(m_epfd, EPOLL_CTL_DEL, fd, nullptr); + state.events = 0; + m_fd_states.erase(fd); + } else if (new_events != state.events) { + struct epoll_event ev; + memset(&ev, 0, sizeof(ev)); + ev.events = new_events; + ev.data.fd = fd; + epoll_ctl(m_epfd, EPOLL_CTL_MOD, fd, &ev); + state.events = new_events; + } + } + + bool async_epoll::submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) { + auto& state = m_fd_states[fd]; + if (state.read_op) return false; // 同一 fd 读方向最多一个 in-flight op + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::read; + op->r.buffer = buffer; + op->r.len = len; + + state.read_op = op; + if (!fd_arm(fd, state)) { + state.read_op = nullptr; + if (!state.write_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + + bool async_epoll::submit_readv(net::fd_t fd, span bufs, uint64_t request_id) { + auto& state = m_fd_states[fd]; + if (state.read_op) return false; + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::readv; + op->wv = 
dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) op->wv[i] = bufs[i]; + + state.read_op = op; + if (!fd_arm(fd, state)) { + state.read_op = nullptr; + if (!state.write_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + + bool async_epoll::submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) { + auto& state = m_fd_states[fd]; + if (state.write_op) return false; // 同一 fd 写方向最多一个 in-flight op + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::write; + op->w.buffer = buffer; + op->w.len = len; + + state.write_op = op; + if (!fd_arm(fd, state)) { + state.write_op = nullptr; + if (!state.read_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + + bool async_epoll::submit_writev(net::fd_t fd, span bufs, uint64_t request_id) { + auto& state = m_fd_states[fd]; + if (state.write_op) return false; // 同一 fd 写方向最多一个 in-flight op + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::writev; + op->wv = dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) op->wv[i] = bufs[i]; + + state.write_op = op; + if (!fd_arm(fd, state)) { + state.write_op = nullptr; + if (!state.read_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + + bool async_epoll::submit_accept(net::fd_t listen_fd, uint64_t request_id) { + auto& state = m_fd_states[listen_fd]; + if (state.read_op) return false; // 同一 fd 读方向最多一个 in-flight op + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = listen_fd; + op->type = pending_op::accept; + + state.read_op = op; + if (!fd_arm(listen_fd, state)) { + state.read_op = nullptr; + if (!state.write_op) m_fd_states.erase(listen_fd); + delete op; + return false; + } + return true; + } + + bool async_epoll::submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) { + auto status = 
net::socket::connect(fd, ep); + if (status == net::socket::status::success) { + io_completion c; + c.request_id = request_id; + c.op = async_op::connect; + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + m_sync_completions.push_back(c); + return true; + } + if (status == net::socket::status::wait) { + auto& state = m_fd_states[fd]; + if (state.write_op) return false; // 同一 fd 写方向最多一个 in-flight op + + auto* op = new pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::connect; + + state.write_op = op; + if (!fd_arm(fd, state)) { + state.write_op = nullptr; + if (!state.read_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + return false; + } + + bool async_epoll::submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) { + io_completion c; + c.request_id = request_id; + c.op = async_op::file_read; + ssize_t n = pread(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + m_sync_completions.push_back(c); + return true; + } + + bool async_epoll::submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) { + io_completion c; + c.request_id = request_id; + c.op = async_op::file_write; + ssize_t n = pwrite(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + m_sync_completions.push_back(c); + return true; + } + + bool async_epoll::submit_poll(net::fd_t fd, uint64_t request_id) { + auto& state = m_fd_states[fd]; + if (state.read_op) return false; // 同一 fd 读方向最多一个 in-flight op + + auto* op = new 
pending_op(); + op->request_id = request_id; + op->fd = fd; + op->type = pending_op::fd_poll; + + state.read_op = op; + if (!fd_arm(fd, state)) { + state.read_op = nullptr; + if (!state.write_op) m_fd_states.erase(fd); + delete op; + return false; + } + return true; + } + + // Process one read-direction event. Returns true if a completion was produced. + // If the syscall returns EAGAIN (spurious wakeup), returns false and leaves op intact. + static bool process_read_op( + int epfd, + net::fd_t fd, + async_epoll::fd_state& state, + std::unordered_map& fd_states, + io_completion& out + ) { + async_epoll::pending_op* op = state.read_op; + out.request_id = op->request_id; + out.bytes_transferred = 0; + out.error_code = 0; + + bool produced = true; + + switch (op->type) { + case async_epoll::pending_op::read: { + out.op = async_op::read; + int rc = 0; + auto rs = net::socket::recv(fd, rc, static_cast(op->r.buffer), static_cast(op->r.len)); + switch (rs) { + case net::socket::recv_status::success: + out.status = async_status::success; + out.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + out.status = async_status::close; + break; + case net::socket::recv_status::failed: + out.status = async_status::error; + out.error_code = errno; + break; + case net::socket::recv_status::wait: + // Spurious EPOLLIN: no data yet, keep op, do not emit completion. 
+ produced = false; + break; + } + break; + } + case async_epoll::pending_op::readv: { + out.op = async_op::readv; + int rc = 0; + auto rs = net::socket::recvv(fd, rc, span(op->wv.data(), op->wv.size())); + switch (rs) { + case net::socket::recv_status::success: + out.status = async_status::success; + out.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + out.status = async_status::close; + break; + case net::socket::recv_status::failed: + out.status = async_status::error; + out.error_code = errno; + break; + case net::socket::recv_status::wait: + produced = false; + break; + } + break; + } + case async_epoll::pending_op::accept: { + out.op = async_op::accept; + net::fd_t newfd = net::retired_fd; + auto as = net::socket::accept(fd, newfd); + switch (as) { + case net::socket::status::success: + out.status = async_status::success; + out.bytes_transferred = static_cast(newfd); + break; + case net::socket::status::wait: + produced = false; + break; + case net::socket::status::failed: + out.status = async_status::error; + out.error_code = errno; + break; + } + break; + } + case async_epoll::pending_op::fd_poll: { + // fd_poll: 只通知 fd 可读,不消费任何数据 + out.op = async_op::fd_poll; + out.status = async_status::success; + out.bytes_transferred = 0; + out.error_code = 0; + break; + } + default: + produced = false; + break; + } + + if (produced) { + // Completion ready: clear read slot and update epoll mask. + delete op; + state.read_op = nullptr; + uint32_t new_events = state.write_op ? EPOLLOUT : 0; + if (new_events == 0) { + epoll_ctl(epfd, EPOLL_CTL_DEL, fd, nullptr); + state.events = 0; + fd_states.erase(fd); + } else if (new_events != state.events) { + struct epoll_event ev; + memset(&ev, 0, sizeof(ev)); + ev.events = new_events; + ev.data.fd = fd; + epoll_ctl(epfd, EPOLL_CTL_MOD, fd, &ev); + state.events = new_events; + } + } + return produced; + } + + // Process one write-direction event. Returns true if a completion was produced. 
+ static bool process_write_op( + int epfd, + net::fd_t fd, + async_epoll::fd_state& state, + std::unordered_map& fd_states, + io_completion& out + ) { + async_epoll::pending_op* op = state.write_op; + out.request_id = op->request_id; + out.bytes_transferred = 0; + out.error_code = 0; + + bool produced = true; + + switch (op->type) { + case async_epoll::pending_op::write: { + out.op = async_op::write; + int rc = 0; + auto ss = net::socket::send(fd, rc, static_cast(op->w.buffer), static_cast(op->w.len)); + switch (ss) { + case net::socket::status::success: + out.status = async_status::success; + out.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + produced = false; + break; + case net::socket::status::failed: + out.status = async_status::error; + out.error_code = errno; + break; + } + break; + } + case async_epoll::pending_op::writev: { + out.op = async_op::writev; + int rc = 0; + auto ss = net::socket::sendv(fd, rc, span(op->wv.data(), op->wv.size())); + switch (ss) { + case net::socket::status::success: + out.status = async_status::success; + out.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + produced = false; + break; + case net::socket::status::failed: + out.status = async_status::error; + out.error_code = errno; + break; + } + break; + } + case async_epoll::pending_op::connect: { + out.op = async_op::connect; + int err = 0; + if (net::socket::errcode(fd, err) && err == 0) { + out.status = async_status::success; + } else { + out.status = async_status::error; + out.error_code = (err != 0) ? err : errno; + } + break; + } + default: + produced = false; + break; + } + + if (produced) { + delete op; + state.write_op = nullptr; + uint32_t new_events = state.read_op ? 
EPOLLIN : 0; + if (new_events == 0) { + epoll_ctl(epfd, EPOLL_CTL_DEL, fd, nullptr); + state.events = 0; + fd_states.erase(fd); + } else if (new_events != state.events) { + struct epoll_event ev; + memset(&ev, 0, sizeof(ev)); + ev.events = new_events; + ev.data.fd = fd; + epoll_ctl(epfd, EPOLL_CTL_MOD, fd, &ev); + state.events = new_events; + } + } + return produced; + } + + static int drain_epoll( + int epfd, + const span& completions, + int timeout_ms, + std::unordered_map& fd_states + ) { + constexpr int kMaxEvents = 64; + struct epoll_event events[kMaxEvents]; + int nfds = epoll_wait(epfd, events, kMaxEvents, timeout_ms); + if (nfds <= 0) { + return 0; + } + + int count = 0; + for (int i = 0; i < nfds && count < static_cast(completions.size()); ++i) { + uint32_t ev = events[i].events; + net::fd_t fd = static_cast(events[i].data.fd); + + auto it = fd_states.find(fd); + if (it == fd_states.end()) continue; + auto& state = it->second; + + // Propagate errors to both directions. + if (ev & (EPOLLERR | EPOLLHUP)) { + ev |= EPOLLIN | EPOLLOUT; + } + + // Process read direction. + if ((ev & EPOLLIN) && state.read_op) { + io_completion c; + if (process_read_op(epfd, fd, state, fd_states, c)) { + completions[count++] = c; + } + } + + // Process write direction (re-lookup in case erase happened above). 
+ if (count < static_cast(completions.size()) && (ev & EPOLLOUT)) { + auto it2 = fd_states.find(fd); + if (it2 != fd_states.end() && it2->second.write_op) { + io_completion c; + if (process_write_op(epfd, fd, it2->second, fd_states, c)) { + completions[count++] = c; + } + } + } + } + return count; + } + + int async_epoll::poll(const span& completions) { + int count = 0; + + while (!m_sync_completions.empty() && count < static_cast(completions.size())) { + completions[count++] = m_sync_completions.front(); + m_sync_completions.pop_front(); + } + if (count >= static_cast(completions.size())) { + return count; + } + + count += drain_epoll(m_epfd, span(completions.data() + count, completions.size() - count), 0, m_fd_states); + return count; + } + + int async_epoll::wait(const span& completions, int timeout) { + int count = 0; + + while (!m_sync_completions.empty() && count < static_cast(completions.size())) { + completions[count++] = m_sync_completions.front(); + m_sync_completions.pop_front(); + } + if (count > 0 || completions.size() == 0) { + return count; + } + + count += drain_epoll(m_epfd, completions, timeout, m_fd_states); + return count; + } + + void async_epoll::stop() { + if (m_epfd >= 0) { + close(m_epfd); + m_epfd = -1; + } + for (auto& [fd, state] : m_fd_states) { + delete state.read_op; + delete state.write_op; + } + m_fd_states.clear(); + } + + void async_epoll::cancel(net::fd_t fd) { + auto it = m_fd_states.find(fd); + if (it == m_fd_states.end()) return; + auto& state = it->second; + epoll_ctl(m_epfd, EPOLL_CTL_DEL, fd, nullptr); + delete state.read_op; + delete state.write_op; + m_fd_states.erase(it); + } + +} // namespace bee::async diff --git a/bee/async/async_epoll_linux.h b/bee/async/async_epoll_linux.h new file mode 100644 index 00000000..acb748c9 --- /dev/null +++ b/bee/async/async_epoll_linux.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace 
bee::net { + struct endpoint; +} + +namespace bee::async { + + class async_epoll : public async { + public: + async_epoll(); + ~async_epoll() override; + + bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) override; + bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id) override; + bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) override; + bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id) override; + bool submit_accept(net::fd_t listen_fd, uint64_t request_id) override; + bool submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) override; + bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) override; + bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) override; + bool submit_poll(net::fd_t fd, uint64_t request_id) override; + int poll(const span& completions) override; + int wait(const span& completions, int timeout) override; + void stop() override; + void cancel(net::fd_t fd) override; + + struct pending_op { + uint64_t request_id = 0; + net::fd_t fd = net::retired_fd; + enum type_t : uint8_t { + read, + readv, + write, + writev, + accept, + connect, + fd_poll, + } type = read; + union { + struct { + void* buffer; + size_t len; + } r; + struct { + const void* buffer; + size_t len; + } w; + }; + dynarray wv; // used when type == writev or readv + }; + + // Per-fd state: tracks up to one read-direction and one write-direction pending op, + // plus the currently registered epoll events mask. 
+ struct fd_state { + pending_op* read_op = nullptr; // EPOLLIN direction (read/accept/connect) + pending_op* write_op = nullptr; // EPOLLOUT direction (write/writev) + uint32_t events = 0; // currently registered events mask + }; + + private: + int m_epfd; + std::deque m_sync_completions; + std::unordered_map m_fd_states; + + static constexpr int kMaxEvents = 64; + + // Register or update epoll for fd, merging read_op/write_op into a combined event mask. + // Returns false on epoll_ctl failure. + bool fd_arm(net::fd_t fd, fd_state& state); + + // Remove one direction from fd_state; DEL if both directions gone. + void fd_disarm(net::fd_t fd, fd_state& state, bool is_write); + }; + +} // namespace bee::async diff --git a/bee/async/async_osx.cpp b/bee/async/async_osx.cpp new file mode 100644 index 00000000..0a07dd82 --- /dev/null +++ b/bee/async/async_osx.cpp @@ -0,0 +1,431 @@ +#include +#include +#include +#include +#include + +#include +#include + +namespace bee::async { + + async::async() + : m_queue(dispatch_queue_create("bee.async.gcd", DISPATCH_QUEUE_SERIAL)) + , m_signal(dispatch_semaphore_create(0)) + , m_stopped(false) {} + + async::~async() { + stop(); + } + + void async::enqueue_completion(const io_completion& c) { + m_completions.push_back(c); + dispatch_semaphore_signal(m_signal); + } + + int async::drain(const span& completions) { + int count = 0; + while (!m_completions.empty() && count < static_cast(completions.size())) { + completions[count++] = m_completions.front(); + m_completions.pop_front(); + } + return count; + } + + // Get or create per-fd sources. Called from Lua thread before submit. + async::fd_sources* async::get_or_create(net::fd_t fd) { + auto it = m_fd_map.find(fd); + if (it != m_fd_map.end()) return it->second; + + auto* s = new fd_sources(); + m_fd_map[fd] = s; + + // Create persistent read source (starts suspended). 
+ s->read_src = dispatch_source_create(DISPATCH_SOURCE_TYPE_READ, fd, 0, m_queue); + if (!s->read_src) { delete s; m_fd_map.erase(fd); return nullptr; } + dispatch_source_set_event_handler(s->read_src, ^{ + if (!s->r.pending) return; + io_completion c; + c.request_id = s->r.request_id; + c.error_code = 0; + if (s->r.is_readv) { + c.op = async_op::readv; + int rc = 0; + auto rs = net::socket::recvv(fd, rc, span(s->r.iov.data(), s->r.iov.size())); + switch (rs) { + case net::socket::recv_status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + c.status = async_status::close; + c.bytes_transferred = 0; + break; + case net::socket::recv_status::wait: + return; // spurious wakeup, stay resumed + case net::socket::recv_status::failed: + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + break; + } + } else { + c.op = async_op::read; + int rc = 0; + auto rs = net::socket::recv(fd, rc, static_cast(s->r.buffer), static_cast(s->r.len)); + switch (rs) { + case net::socket::recv_status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::recv_status::close: + c.status = async_status::close; + c.bytes_transferred = 0; + break; + case net::socket::recv_status::wait: + return; // spurious wakeup, stay resumed + case net::socket::recv_status::failed: + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + break; + } + } + s->r.pending = false; + dispatch_suspend(s->read_src); + this->enqueue_completion(c); + }); + dispatch_source_set_cancel_handler(s->read_src, ^{ + dispatch_release(s->read_src); + s->read_src = nullptr; + }); + + // Create persistent write source (starts suspended). 
+ s->write_src = dispatch_source_create(DISPATCH_SOURCE_TYPE_WRITE, fd, 0, m_queue); + if (!s->write_src) { + dispatch_source_cancel(s->read_src); + delete s; m_fd_map.erase(fd); return nullptr; + } + dispatch_source_set_event_handler(s->write_src, ^{ + if (!s->w.pending) return; + io_completion c; + c.request_id = s->w.request_id; + c.error_code = 0; + if (s->w.is_writev) { + c.op = async_op::writev; + int rc = 0; + auto ss = net::socket::sendv(fd, rc, span(s->w.iov.data(), s->w.iov.size())); + switch (ss) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + return; + case net::socket::status::failed: + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + break; + } + } else { + c.op = async_op::write; + int rc = 0; + auto ss = net::socket::send(fd, rc, static_cast(s->w.buffer), static_cast(s->w.len)); + switch (ss) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(rc); + break; + case net::socket::status::wait: + return; + case net::socket::status::failed: + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + break; + } + } + s->w.pending = false; + dispatch_suspend(s->write_src); + this->enqueue_completion(c); + }); + dispatch_source_set_cancel_handler(s->write_src, ^{ + dispatch_release(s->write_src); + s->write_src = nullptr; + }); + + // Both sources start suspended; resume only on demand. + return s; + } + + void async::release_fd(net::fd_t fd) { + auto it = m_fd_map.find(fd); + if (it == m_fd_map.end()) return; + fd_sources* s = it->second; + m_fd_map.erase(it); + // Cancel sources; cancel handlers will release and set ptr to nullptr. + if (s->read_src) { + // Must resume before cancel if suspended, otherwise cancel blocks. 
+ if (!s->r.pending) dispatch_resume(s->read_src); + dispatch_source_cancel(s->read_src); + } + if (s->write_src) { + if (!s->w.pending) dispatch_resume(s->write_src); + dispatch_source_cancel(s->write_src); + } + // Delete the fd_sources struct after dispatch finishes on m_queue. + dispatch_async(m_queue, ^{ delete s; }); + } + + bool async::associate(net::fd_t fd) { + return get_or_create(fd) != nullptr; + } + + bool async::submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) { + fd_sources* s = get_or_create(fd); + if (!s || s->r.pending) return false; + s->r.buffer = buffer; + s->r.len = len; + s->r.request_id = request_id; + s->r.pending = true; + s->r.is_readv = false; + dispatch_resume(s->read_src); + return true; + } + + bool async::submit_readv(net::fd_t fd, span bufs, uint64_t request_id) { + fd_sources* s = get_or_create(fd); + if (!s || s->r.pending) return false; + s->r.iov = dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) s->r.iov[i] = bufs[i]; + s->r.request_id = request_id; + s->r.pending = true; + s->r.is_readv = true; + dispatch_resume(s->read_src); + return true; + } + + bool async::submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) { + fd_sources* s = get_or_create(fd); + if (!s || s->w.pending) return false; + s->w.buffer = buffer; + s->w.len = len; + s->w.request_id = request_id; + s->w.pending = true; + s->w.is_writev = false; + dispatch_resume(s->write_src); + return true; + } + + bool async::submit_writev(net::fd_t fd, span bufs, uint64_t request_id) { + fd_sources* s = get_or_create(fd); + if (!s || s->w.pending) return false; + s->w.iov = dynarray(bufs.size()); + for (size_t i = 0; i < bufs.size(); ++i) s->w.iov[i] = bufs[i]; + s->w.request_id = request_id; + s->w.pending = true; + s->w.is_writev = true; + dispatch_resume(s->write_src); + return true; + } + + bool async::submit_accept(net::fd_t listen_fd, uint64_t request_id) { + fd_sources* s = get_or_create(listen_fd); + 
if (!s || s->r.pending) return false; + s->r.buffer = nullptr; + s->r.len = 0; + s->r.request_id = request_id; + s->r.pending = true; + // Reuse read source; override event handler for accept. + // Use a separate accept source to avoid conflating read/accept. + // For simplicity, create a one-shot accept source (accept is rare). + // (accept happens once per connection, so one-shot overhead is negligible) + s->r.pending = false; // revert, use one-shot path below + + dispatch_source_t src = dispatch_source_create(DISPATCH_SOURCE_TYPE_READ, listen_fd, 0, m_queue); + if (!src) return false; + dispatch_source_set_event_handler(src, ^{ + net::fd_t newfd = net::retired_fd; + auto as = net::socket::accept(listen_fd, newfd); + io_completion c; + c.request_id = request_id; + c.op = async_op::accept; + switch (as) { + case net::socket::status::success: + c.status = async_status::success; + c.bytes_transferred = static_cast(newfd); + c.error_code = 0; + break; + case net::socket::status::wait: + return; + case net::socket::status::failed: + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + break; + } + dispatch_source_cancel(src); + this->enqueue_completion(c); + }); + dispatch_source_set_cancel_handler(src, ^{ dispatch_release(src); }); + dispatch_resume(src); + return true; + } + + bool async::submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) { + auto status = net::socket::connect(fd, ep); + if (status == net::socket::status::success) { + dispatch_async(m_queue, ^{ + io_completion c; + c.request_id = request_id; + c.op = async_op::connect; + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + this->enqueue_completion(c); + }); + return true; + } + if (status == net::socket::status::wait) { + // Connect completion comes as write-ready; use one-shot (rare). 
+ dispatch_source_t src = dispatch_source_create(DISPATCH_SOURCE_TYPE_WRITE, fd, 0, m_queue); + if (!src) return false; + dispatch_source_set_event_handler(src, ^{ + int err = 0; + io_completion c; + c.request_id = request_id; + c.op = async_op::connect; + if (net::socket::errcode(fd, err) && err == 0) { + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = (err != 0) ? err : errno; + } + dispatch_source_cancel(src); + this->enqueue_completion(c); + }); + dispatch_source_set_cancel_handler(src, ^{ dispatch_release(src); }); + dispatch_resume(src); + return true; + } + return false; + } + + bool async::submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) { + dispatch_async(m_queue, ^{ + io_completion c; + c.request_id = request_id; + c.op = async_op::file_read; + ssize_t n = pread(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + this->enqueue_completion(c); + }); + return true; + } + + bool async::submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) { + dispatch_async(m_queue, ^{ + io_completion c; + c.request_id = request_id; + c.op = async_op::file_write; + ssize_t n = pwrite(fd, buffer, len, offset); + if (n >= 0) { + c.status = async_status::success; + c.bytes_transferred = static_cast(n); + c.error_code = 0; + } else { + c.status = async_status::error; + c.bytes_transferred = 0; + c.error_code = errno; + } + this->enqueue_completion(c); + }); + return true; + } + + void async::cancel(net::fd_t fd) { + release_fd(fd); + } + + bool async::submit_poll(net::fd_t fd, uint64_t request_id) { + dispatch_source_t source = 
dispatch_source_create(DISPATCH_SOURCE_TYPE_READ, fd, 0, m_queue); + if (!source) { + return false; + } + + dispatch_source_set_event_handler(source, ^{ + // fd_poll: 只通知 fd 可读,不消费任何数据 + io_completion c; + c.request_id = request_id; + c.op = async_op::fd_poll; + c.status = async_status::success; + c.bytes_transferred = 0; + c.error_code = 0; + dispatch_source_cancel(source); + this->enqueue_completion(c); + }); + + dispatch_source_set_cancel_handler(source, ^{ + dispatch_release(source); + }); + + dispatch_resume(source); + return true; + } + + int async::poll(const span& completions) { + __block int count = 0; + dispatch_sync(m_queue, ^{ count = drain(completions); }); + for (int i = 0; i < count; ++i) + dispatch_semaphore_wait(m_signal, DISPATCH_TIME_NOW); + return count; + } + + int async::wait(const span& completions, int timeout) { + dispatch_time_t when = (timeout < 0) + ? DISPATCH_TIME_FOREVER + : dispatch_time(DISPATCH_TIME_NOW, static_cast(timeout) * NSEC_PER_MSEC); + if (dispatch_semaphore_wait(m_signal, when) != 0) return 0; + __block int count = 0; + dispatch_sync(m_queue, ^{ count = drain(completions); }); + for (int i = 1; i < count; ++i) + dispatch_semaphore_wait(m_signal, DISPATCH_TIME_NOW); + return count; + } + + void async::stop() { + if (!m_stopped) { + m_stopped = true; + // Cancel all fd sources. 
+ for (auto& [fd, s] : m_fd_map) { + if (s->read_src) { if (!s->r.pending) dispatch_resume(s->read_src); dispatch_source_cancel(s->read_src); } + if (s->write_src) { if (!s->w.pending) dispatch_resume(s->write_src); dispatch_source_cancel(s->write_src); } + } + if (m_queue) { + dispatch_sync(m_queue, ^{ + for (auto& [fd, s] : m_fd_map) delete s; + m_fd_map.clear(); + }); + dispatch_release(m_queue); + m_queue = nullptr; + } + if (m_signal) { + dispatch_release(m_signal); + m_signal = nullptr; + } + } + } + +} // namespace bee::async diff --git a/bee/async/async_osx.h b/bee/async/async_osx.h new file mode 100644 index 00000000..92d65ddd --- /dev/null +++ b/bee/async/async_osx.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace bee::net { + struct endpoint; +} + +namespace bee::async { + + class async { + public: + async(); + ~async(); + + bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id); + bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id); + bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id); + bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id); + bool submit_accept(net::fd_t listen_fd, uint64_t request_id); + bool submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id); + bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id); + bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id); + bool associate(net::fd_t fd); + bool submit_poll(net::fd_t fd, uint64_t request_id); + int poll(const span& completions); + int wait(const span& completions, int timeout); + void cancel(net::fd_t fd); + void stop(); + + private: + dispatch_queue_t m_queue; + dispatch_semaphore_t m_signal; + std::deque m_completions; // only accessed on m_queue + 
bool m_stopped; + + // Per-fd persistent read/write sources (reused via suspend/resume) + struct fd_sources { + dispatch_source_t read_src = nullptr; + dispatch_source_t write_src = nullptr; + // Current pending op for each direction (updated before resume) + struct read_op { + void* buffer = nullptr; + size_t len = 0; + uint64_t request_id = 0; + bool pending = false; + bool is_readv = false; + dynarray iov; + } r; + struct write_op { + const void* buffer = nullptr; + size_t len = 0; + uint64_t request_id = 0; + bool pending = false; + bool is_writev = false; + dynarray iov; + } w; + }; + + std::unordered_map m_fd_map; + + fd_sources* get_or_create(net::fd_t fd); + void release_fd(net::fd_t fd); + + // Enqueue a completion from the GCD queue thread. + void enqueue_completion(const io_completion& c); + + // Drain completions into output span. + int drain(const span& completions); + }; + +} // namespace bee::async diff --git a/bee/async/async_types.h b/bee/async/async_types.h new file mode 100644 index 00000000..afb2f12f --- /dev/null +++ b/bee/async/async_types.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +namespace bee::async { + + enum class async_status : uint8_t { + success, + close, + error, + cancel, + }; + + enum class async_op : uint8_t { + read, + readv, + write, + writev, + accept, + connect, + file_read, + file_write, + fd_poll, + timeout, // internal: IORING_OP_TIMEOUT fallback, never surfaced to caller + }; + + struct io_completion { + uint64_t request_id; + async_status status; + async_op op; + size_t bytes_transferred; + int error_code; + }; + +} // namespace bee::async diff --git a/bee/async/async_uring_linux.cpp b/bee/async/async_uring_linux.cpp new file mode 100644 index 00000000..2d31e4df --- /dev/null +++ b/bee/async/async_uring_linux.cpp @@ -0,0 +1,672 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// ---- io_uring ABI 
definitions (no dependency on liburing or ) ---- +// +// All constants and struct layouts are taken directly from the Linux UAPI headers +// (linux/io_uring.h) and verified against the kernel source. Static assertions +// below guard the struct layouts so a mismatch is caught at compile time. + +#ifndef __NR_io_uring_setup +# define __NR_io_uring_setup 425 +#endif +#ifndef __NR_io_uring_enter +# define __NR_io_uring_enter 426 +#endif + +// io_uring_setup flags +enum { + BEE__IORING_SETUP_NO_SQARRAY = 0x10000u, // kernel 6.6+: sq_array is implicit +}; + +// io_uring feature flags (returned in io_uring_params.features) +enum { + BEE__IORING_FEAT_SINGLE_MMAP = 1u, // SQ+CQ share a single mmap region + BEE__IORING_FEAT_NODROP = 2u, // CQ overflow is never silently dropped +}; + +// io_uring_enter flags +enum { + BEE__IORING_ENTER_GETEVENTS = 1u, + BEE__IORING_ENTER_EXT_ARG = 8u, // arg is io_uring_getevents_arg (kernel 5.11+) +}; + +// sq_ring flags (iou->sqflags) +enum { + BEE__IORING_SQ_CQ_OVERFLOW = 2u, +}; + +// Opcodes we use +enum { + BEE__IORING_OP_ACCEPT = 13, + BEE__IORING_OP_CONNECT = 16, + BEE__IORING_OP_READ = 22, + BEE__IORING_OP_WRITE = 23, + BEE__IORING_OP_SEND = 26, + BEE__IORING_OP_RECV = 27, + BEE__IORING_OP_SENDMSG = 9, + BEE__IORING_OP_RECVMSG = 10, + BEE__IORING_OP_POLL_ADD = 6, + BEE__IORING_OP_TIMEOUT = 11, // kernel 5.4+ +}; + +struct bee__io_sqring_offsets { + uint32_t head; + uint32_t tail; + uint32_t ring_mask; + uint32_t ring_entries; + uint32_t flags; + uint32_t dropped; + uint32_t array; + uint32_t reserved0; + uint64_t reserved1; +}; +static_assert(40 == sizeof(bee__io_sqring_offsets), "sqring_offsets size"); + +struct bee__io_cqring_offsets { + uint32_t head; + uint32_t tail; + uint32_t ring_mask; + uint32_t ring_entries; + uint32_t overflow; + uint32_t cqes; + uint64_t reserved0; + uint64_t reserved1; +}; +static_assert(40 == sizeof(bee__io_cqring_offsets), "cqring_offsets size"); + +struct bee__io_uring_sqe { + uint8_t opcode; + 
uint8_t flags; + uint16_t ioprio; + int32_t fd; + union { + uint64_t off; + uint64_t addr2; + }; + union { + uint64_t addr; + }; + uint32_t len; + union { + uint32_t rw_flags; + uint32_t fsync_flags; + uint32_t open_flags; + uint32_t statx_flags; + uint32_t accept_flags; // used by IORING_OP_ACCEPT + uint32_t msg_flags; // used by IORING_OP_SEND / RECV + }; + uint64_t user_data; + union { + uint16_t buf_index; + uint64_t pad[3]; + }; +}; +static_assert(64 == sizeof(bee__io_uring_sqe), "sqe size"); +static_assert(0 == __builtin_offsetof(bee__io_uring_sqe, opcode), "sqe.opcode"); +static_assert(4 == __builtin_offsetof(bee__io_uring_sqe, fd), "sqe.fd"); +static_assert(8 == __builtin_offsetof(bee__io_uring_sqe, off), "sqe.off"); +static_assert(16 == __builtin_offsetof(bee__io_uring_sqe, addr), "sqe.addr"); +static_assert(24 == __builtin_offsetof(bee__io_uring_sqe, len), "sqe.len"); +static_assert(28 == __builtin_offsetof(bee__io_uring_sqe, rw_flags), "sqe.rw_flags"); +static_assert(32 == __builtin_offsetof(bee__io_uring_sqe, user_data), "sqe.user_data"); +static_assert(40 == __builtin_offsetof(bee__io_uring_sqe, buf_index), "sqe.buf_index"); + +struct bee__io_uring_cqe { + uint64_t user_data; + int32_t res; + uint32_t flags; +}; +static_assert(16 == sizeof(bee__io_uring_cqe), "cqe size"); + +struct bee__io_uring_params { + uint32_t sq_entries; + uint32_t cq_entries; + uint32_t flags; + uint32_t sq_thread_cpu; + uint32_t sq_thread_idle; + uint32_t features; + uint32_t reserved[4]; + bee__io_sqring_offsets sq_off; // 40 bytes + bee__io_cqring_offsets cq_off; // 40 bytes +}; +static_assert(40 + 40 + 40 == sizeof(bee__io_uring_params), "params size"); +static_assert(40 == __builtin_offsetof(bee__io_uring_params, sq_off), "params.sq_off"); +static_assert(80 == __builtin_offsetof(bee__io_uring_params, cq_off), "params.cq_off"); + +// Used with IORING_ENTER_EXT_ARG to pass a timeout directly to io_uring_enter. 
+struct bee__io_uring_getevents_arg { + uint64_t sigmask; + uint32_t sigmask_sz; + uint32_t pad; + uint64_t ts; // pointer to __kernel_timespec +}; + +struct bee__kernel_timespec { + int64_t tv_sec; + int64_t tv_nsec; +}; + +// ---- raw syscall wrappers ---- + +static inline int sys_io_uring_setup(unsigned entries, bee__io_uring_params* p) noexcept { + return static_cast(syscall(__NR_io_uring_setup, entries, p)); +} + +static inline int sys_io_uring_enter(int fd, unsigned to_submit, unsigned min_complete, unsigned flags, const void* arg) noexcept { + const unsigned arg_size = (flags & BEE__IORING_ENTER_EXT_ARG) + ? static_cast(sizeof(bee__io_uring_getevents_arg)) + : 0u; + return static_cast(syscall(__NR_io_uring_enter, fd, to_submit, min_complete, flags, arg, arg_size)); +} + +// ---- io_uring ring state (kept behind the forward-declared pointer in the header) ---- + +// Context kept alive on the heap for the duration of a SENDMSG operation. +// msghdr.msg_iov points into bufs[], so both must outlive the CQE. 
+struct writev_ctx { + struct msghdr msg = {}; + bee::dynarray bufs; + explicit writev_ctx(bee::span src) + : bufs(src.size()) { + for (size_t i = 0; i < src.size(); ++i) bufs[i] = src[i]; + msg.msg_iov = reinterpret_cast(bufs.data()); + msg.msg_iovlen = static_cast(src.size()); + } +}; + +struct readv_ctx { + struct msghdr msg = {}; + bee::dynarray bufs; + explicit readv_ctx(bee::span src) + : bufs(src.size()) { + for (size_t i = 0; i < src.size(); ++i) bufs[i] = src[i]; + msg.msg_iov = reinterpret_cast(bufs.data()); + msg.msg_iovlen = static_cast(src.size()); + } +}; + +struct io_uring { + int ringfd = -1; + char* sq = nullptr; // base of the shared SQ+CQ mmap + size_t maxlen = 0; + bee__io_uring_sqe* sqe = nullptr; + size_t sqelen = 0; + + // SQ ring pointers into sq mmap + uint32_t* sqhead = nullptr; // kernel consumer + uint32_t* sqtail = nullptr; // we publish here + uint32_t* sqflags = nullptr; // SQ_NEED_WAKEUP / SQ_CQ_OVERFLOW flags + uint32_t sqmask = 0; + + // CQ ring pointers into sq mmap + uint32_t* cqhead = nullptr; // we advance (consumer) + uint32_t* cqtail = nullptr; // kernel publishes here + uint32_t cqmask = 0; + bee__io_uring_cqe* cqes = nullptr; + + // Runtime capability flag: IORING_ENTER_EXT_ARG is supported (kernel 5.11+). + // Probed on first use; false means we fall back to IORING_OP_TIMEOUT SQE. + bool ext_arg_supported = true; + + // Pending writev contexts keyed by request_id, freed when CQE arrives. + std::unordered_map> writev_pending; + // Pending readv contexts keyed by request_id, freed when CQE arrives. + std::unordered_map> readv_pending; +}; + +namespace bee::async { + + static constexpr uint32_t kEntries = 256; + + // Pack op type into the high 8 bits of user_data; request_id uses the low 56 bits. 
+ static constexpr uint64_t kOpShift = 56; + static constexpr uint64_t kIdMask = (uint64_t(1) << kOpShift) - 1; + + static inline uint64_t pack_user_data(async_op op, uint64_t request_id) noexcept { + return (static_cast(op) << kOpShift) | (request_id & kIdMask); + } + + static inline async_op unpack_op(uint64_t user_data) noexcept { + return static_cast(user_data >> kOpShift); + } + + static inline uint64_t unpack_id(uint64_t user_data) noexcept { + return user_data & kIdMask; + } + + // ---- atomic helpers (matching libuv's acquire/release ordering) ---- + + static inline uint32_t load_acquire(const uint32_t* p) noexcept { + return __atomic_load_n(p, __ATOMIC_ACQUIRE); + } + + static inline void store_release(uint32_t* p, uint32_t v) noexcept { + __atomic_store_n(p, v, __ATOMIC_RELEASE); + } + + // ---- ring init / exit ---- + + static bool uring_init(uint32_t entries, io_uring* ring) noexcept { + bee__io_uring_params params; + memset(¶ms, 0, sizeof(params)); + + // On kernel 6.6+ the kernel can omit the sq_array indirection via + // IORING_SETUP_NO_SQARRAY. We intentionally do not request that flag here: + // unknown setup flags may be rejected on older kernels, and the existing + // sq_array initialisation path already works for both layouts. + int ringfd = sys_io_uring_setup(entries, ¶ms); + if (ringfd < 0) return false; + + // Require only the features that are actually used below: + // SINGLE_MMAP (Linux 5.4+) and NODROP (Linux 5.5+). + if (!(params.features & BEE__IORING_FEAT_SINGLE_MMAP)) { + close(ringfd); + return false; + } + if (!(params.features & BEE__IORING_FEAT_NODROP)) { + close(ringfd); + return false; + } + + // SQ+CQ share one mmap (SINGLE_MMAP): use the larger of the two regions. + size_t sqlen = params.sq_off.array + params.sq_entries * sizeof(uint32_t); + size_t cqlen = params.cq_off.cqes + params.cq_entries * sizeof(bee__io_uring_cqe); + size_t maxlen = sqlen < cqlen ? 
cqlen : sqlen; + size_t sqelen = params.sq_entries * sizeof(bee__io_uring_sqe); + + char* sq = static_cast( + mmap(nullptr, maxlen, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, ringfd, 0 /* IORING_OFF_SQ_RING */) + ); + if (sq == MAP_FAILED) { + close(ringfd); + return false; + } + + bee__io_uring_sqe* sqe_ptr = static_cast( + mmap(nullptr, sqelen, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, ringfd, 0x10000000ull /* IORING_OFF_SQES */) + ); + if (sqe_ptr == MAP_FAILED) { + munmap(sq, maxlen); + close(ringfd); + return false; + } + + ring->ringfd = ringfd; + ring->sq = sq; + ring->maxlen = maxlen; + ring->sqe = sqe_ptr; + ring->sqelen = sqelen; + + ring->sqhead = reinterpret_cast(sq + params.sq_off.head); + ring->sqtail = reinterpret_cast(sq + params.sq_off.tail); + ring->sqflags = reinterpret_cast(sq + params.sq_off.flags); + ring->sqmask = *reinterpret_cast(sq + params.sq_off.ring_mask); + + ring->cqhead = reinterpret_cast(sq + params.cq_off.head); + ring->cqtail = reinterpret_cast(sq + params.cq_off.tail); + ring->cqmask = *reinterpret_cast(sq + params.cq_off.ring_mask); + ring->cqes = reinterpret_cast(sq + params.cq_off.cqes); + + // Pre-fill sq_array with the identity mapping (slot i -> SQE i). + // On kernels that set NO_SQARRAY the kernel ignores this array, but + // populating it is harmless and keeps a single code path. + if (!(params.flags & BEE__IORING_SETUP_NO_SQARRAY)) { + uint32_t* sqarray = reinterpret_cast(sq + params.sq_off.array); + for (uint32_t i = 0; i <= ring->sqmask; i++) + sqarray[i] = i; + } + + return true; + } + + static void uring_exit(io_uring* ring) noexcept { + if (ring->ringfd < 0) return; + munmap(ring->sqe, ring->sqelen); + munmap(ring->sq, ring->maxlen); + close(ring->ringfd); + ring->ringfd = -1; + } + + // ---- SQE helpers ---- + + // Returns the next free SQE slot, or nullptr if the SQ is full. + // The caller fills the SQE and then calls uring_submit(). 
    // Claim the next free SQE slot, or return nullptr when the SQ ring is full.
    // The returned SQE is zeroed; the caller fills it in and calls uring_submit.
    static inline bee__io_uring_sqe* uring_get_sqe(io_uring* ring) noexcept {
        uint32_t head = load_acquire(ring->sqhead);
        uint32_t tail = *ring->sqtail;
        uint32_t mask = ring->sqmask;
        // Ring is full only when the number of in-flight SQEs reaches capacity.
        if ((tail - head) >= (mask + 1))
            return nullptr;
        uint32_t slot = tail & mask;
        bee__io_uring_sqe* sqe = &ring->sqe[slot];
        memset(sqe, 0, sizeof(*sqe));
        return sqe;
    }

    // Publish one new SQE to the kernel by advancing sqtail (release ordering).
    // If SQPOLL is not in use this is sufficient; io_uring_enter drives submission.
    static inline void uring_submit(io_uring* ring) noexcept {
        store_release(ring->sqtail, *ring->sqtail + 1);
    }

    // Return the number of SQEs published but not yet consumed by the kernel.
    static inline uint32_t uring_pending(const io_uring* ring) noexcept {
        return *ring->sqtail - load_acquire(ring->sqhead);
    }

    // ---- CQE harvesting ----

    // Drain completed CQEs into completions[], translating io_uring result codes
    // into io_completion records. Returns the number of entries written.
    // Internal timeout CQEs (used by the kernel 5.4-5.10 wait fallback) are
    // consumed here and never surfaced to the caller.
    int async_uring::harvest_cqes(const span& completions) noexcept {
        io_uring* ring = m_ring;
        uint32_t head  = *ring->cqhead;
        uint32_t tail  = load_acquire(ring->cqtail);
        uint32_t mask  = ring->cqmask;
        uint32_t count = 0;

        while (head != tail && count < static_cast(completions.size())) {
            const bee__io_uring_cqe& cqe = ring->cqes[head & mask];
            // Skip internal timeout CQEs — they are never surfaced to the caller.
            if (unpack_op(cqe.user_data) == async_op::timeout) {
                head++;
                continue;
            }
            io_completion& c = completions[count++];
            c.op         = unpack_op(cqe.user_data);
            c.request_id = unpack_id(cqe.user_data);
            // Free the writev context (msghdr + iobuf array) once the CQE arrives.
            if (c.op == async_op::writev) {
                ring->writev_pending.erase(c.request_id);
            }
            // Free the readv context once the CQE arrives.
            if (c.op == async_op::readv) {
                ring->readv_pending.erase(c.request_id);
            }
            // For connect/writev/file_write/accept/fd_poll, res==0 means success (not EOF).
            // For read/write (recv/send), res==0 means the peer closed the connection.
            bool zero_is_success = (c.op == async_op::connect || c.op == async_op::writev || c.op == async_op::file_write || c.op == async_op::accept || c.op == async_op::fd_poll);
            if (cqe.res > 0) {
                c.status = async_status::success;
                // For fd_poll, cqe.res is the revents mask (e.g. POLLIN=1), not a byte count.
                c.bytes_transferred = (c.op == async_op::fd_poll) ? 0 : static_cast(cqe.res);
                c.error_code        = 0;
            } else if (cqe.res == 0) {
                if (zero_is_success) {
                    c.status            = async_status::success;
                    c.bytes_transferred = 0;
                    c.error_code        = 0;
                } else {
                    c.status            = async_status::close;
                    c.bytes_transferred = 0;
                    c.error_code        = 0;
                }
            } else {
                // Negative res is -errno from the kernel.
                c.status            = async_status::error;
                c.bytes_transferred = 0;
                c.error_code        = -cqe.res;
            }
            head++;
        }

        if (count > 0)
            store_release(ring->cqhead, head);

        // If the CQ overflowed, poke the kernel to flush the overflow list.
        // We don't grab the new entries here — they'll appear in the next poll/wait.
        if (load_acquire(ring->sqflags) & BEE__IORING_SQ_CQ_OVERFLOW) {
            int rc;
            do {
                rc = sys_io_uring_enter(ring->ringfd, 0, 0, BEE__IORING_ENTER_GETEVENTS, nullptr);
            } while (rc == -1 && errno == EINTR);
        }

        return static_cast(count);
    }

    // ---- async_uring public interface ----

    // Construct the ring; on init failure m_ring is left null and valid()
    // reports false so the factory can fall back to the epoll backend.
    async_uring::async_uring()
        : m_ring(new io_uring {}) {
        if (!uring_init(kEntries, m_ring)) {
            delete m_ring;
            m_ring = nullptr;
        }
    }

    async_uring::~async_uring() {
        stop();
    }

    // Queue a socket recv. The buffer must stay valid until the CQE is harvested.
    bool async_uring::submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode    = BEE__IORING_OP_RECV;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(buffer);
        sqe->len       = static_cast(len);
        sqe->msg_flags = 0;
        sqe->user_data = pack_user_data(async_op::read, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will be submitted on next poll/wait
    }

    // Queue a scatter recv (recvmsg). The msghdr/iovec context is owned by
    // readv_pending until the CQE arrives (freed in harvest_cqes).
    bool async_uring::submit_readv(net::fd_t fd, span bufs, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        auto ctx = std::make_unique(bufs);
        sqe->opcode    = BEE__IORING_OP_RECVMSG;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(&ctx->msg);
        sqe->len       = 1;
        sqe->msg_flags = 0;
        sqe->user_data = pack_user_data(async_op::readv, request_id);
        m_ring->readv_pending.emplace(request_id, std::move(ctx));
        uring_submit(m_ring);
        return true;
    }

    // Queue a socket send. The buffer must stay valid until the CQE is harvested.
    bool async_uring::submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode    = BEE__IORING_OP_SEND;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(buffer);
        sqe->len       = static_cast(len);
        sqe->msg_flags = 0;
        sqe->user_data = pack_user_data(async_op::write, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will
        // ... be submitted on next poll/wait
    }

    // Queue a gather send (sendmsg). The msghdr/iovec context is owned by
    // writev_pending until the CQE arrives (freed in harvest_cqes).
    bool async_uring::submit_writev(net::fd_t fd, span bufs, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        auto ctx = std::make_unique(bufs);
        sqe->opcode    = BEE__IORING_OP_SENDMSG;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(&ctx->msg);
        sqe->len       = 1;
        sqe->msg_flags = 0;
        sqe->user_data = pack_user_data(async_op::writev, request_id);
        m_ring->writev_pending.emplace(request_id, std::move(ctx));
        uring_submit(m_ring);
        return true;
    }

    // Queue an accept on listen_fd; the accepted fd is delivered via the CQE.
    bool async_uring::submit_accept(net::fd_t listen_fd, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode       = BEE__IORING_OP_ACCEPT;
        sqe->fd           = listen_fd;
        sqe->addr         = 0; // don't capture peer address
        sqe->addr2        = 0; // no socklen_t output
        sqe->accept_flags = SOCK_NONBLOCK | SOCK_CLOEXEC;
        sqe->user_data    = pack_user_data(async_op::accept, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will be submitted on next poll/wait
    }

    // Queue a connect to ep.
    bool async_uring::submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        // The caller (Lua binding) pins the endpoint in the buf table, guaranteeing
        // ep.addr() remains valid until the CQE is harvested.
        sqe->opcode    = BEE__IORING_OP_CONNECT;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(ep.addr());
        sqe->off       = ep.addrlen(); // CONNECT stores addrlen in the off field
        sqe->user_data = pack_user_data(async_op::connect, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will be submitted on next poll/wait
    }

    // Queue a positioned file read (pread-like).
    bool async_uring::submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode    = BEE__IORING_OP_READ;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(buffer);
        sqe->len       = static_cast(len);
        sqe->off       = static_cast(offset);
        sqe->user_data = pack_user_data(async_op::file_read, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will be submitted on next poll/wait
    }

    // Queue a positioned file write (pwrite-like).
    bool async_uring::submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode    = BEE__IORING_OP_WRITE;
        sqe->fd        = fd;
        sqe->addr      = reinterpret_cast(buffer);
        sqe->len       = static_cast(len);
        sqe->off       = static_cast(offset);
        sqe->user_data = pack_user_data(async_op::file_write, request_id);
        uring_submit(m_ring);
        return true; // SQE queued; will be submitted on next poll/wait
    }

    // Queue a one-shot readability poll on fd.
    bool async_uring::submit_poll(net::fd_t fd, uint64_t request_id) {
        if (!m_ring) return false;
        bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
        if (!sqe) return false;
        sqe->opcode    = BEE__IORING_OP_POLL_ADD;
        sqe->fd        = fd;
        sqe->rw_flags  = POLLIN; // watch for readability
        sqe->user_data = pack_user_data(async_op::fd_poll, request_id);
        uring_submit(m_ring);
        return true;
    }

    // Non-blocking harvest: equivalent to wait() with a zero timeout.
    int async_uring::poll(const span& completions) {
        if (!m_ring) return 0;
        return wait(completions, 0);
    }

    // Submit pending SQEs and harvest CQEs.
    //   timeout == 0 : non-blocking flush + harvest.
    //   timeout > 0  : block up to `timeout` milliseconds.
    //   timeout < 0  : block until at least one CQE is available.
    int async_uring::wait(const span& completions, int timeout) {
        if (!m_ring) return 0;
        // Submit any pending SQEs and wait for at least one CQE in a single syscall.
        uint32_t pending = uring_pending(m_ring);

        if (timeout == 0) {
            // Non-blocking: flush pending SQEs then harvest whatever is already done.
            if (pending > 0) {
                int ret;
                do {
                    ret = sys_io_uring_enter(m_ring->ringfd, pending, 0, 0, nullptr);
                } while (ret == -1 && errno == EINTR);
            }
            return harvest_cqes(completions);
        } else if (timeout > 0) {
            if (m_ring->ext_arg_supported) {
                // Fast path (kernel 5.11+): pass timeout directly to io_uring_enter.
                bee__kernel_timespec ts;
                ts.tv_sec  = timeout / 1000;
                ts.tv_nsec = static_cast(timeout % 1000) * 1000000L;
                bee__io_uring_getevents_arg arg;
                memset(&arg, 0, sizeof(arg));
                arg.ts = reinterpret_cast(&ts);
                int ret;
                do {
                    ret = sys_io_uring_enter(m_ring->ringfd, pending, 1, BEE__IORING_ENTER_GETEVENTS | BEE__IORING_ENTER_EXT_ARG, &arg);
                } while (ret == -1 && errno == EINTR);
                if (ret == -1 && errno == EINVAL) {
                    // Kernel does not support EXT_ARG; disable and fall through to TIMEOUT SQE path.
                    m_ring->ext_arg_supported = false;
                } else {
                    // errno == ETIME: timeout expired with 0 completions; harvest anyway.
                    return harvest_cqes(completions);
                }
            }
            if (!m_ring->ext_arg_supported) {
                // Fallback for kernel 5.4-5.10: submit a TIMEOUT SQE alongside any
                // pending SQEs, then block until either an IO CQE or the timeout fires.
                bee__io_uring_sqe* sqe = uring_get_sqe(m_ring);
                if (sqe) {
                    // ts lives on the stack: safe because the TIMEOUT SQE is submitted
                    // synchronously by the io_uring_enter call below.
                    bee__kernel_timespec ts;
                    ts.tv_sec  = timeout / 1000;
                    ts.tv_nsec = static_cast(timeout % 1000) * 1000000L;
                    memset(sqe, 0, sizeof(*sqe));
                    sqe->opcode    = BEE__IORING_OP_TIMEOUT;
                    sqe->addr      = reinterpret_cast(&ts);
                    sqe->len       = 1; // min_complete: fire after 1 other CQE or on expiry
                    sqe->user_data = pack_user_data(async_op::timeout, 0);
                    uring_submit(m_ring);
                    pending = uring_pending(m_ring);
                }
                int ret;
                do {
                    ret = sys_io_uring_enter(m_ring->ringfd, pending, 1, BEE__IORING_ENTER_GETEVENTS, nullptr);
                } while (ret == -1 && errno == EINTR);
            }
        } else {
            // Block until at least one CQE is available, submitting pending SQEs atomically.
            int ret;
            do {
                ret = sys_io_uring_enter(m_ring->ringfd, pending, 1, BEE__IORING_ENTER_GETEVENTS, nullptr);
            } while (ret == -1 && errno == EINTR);
        }

        return harvest_cqes(completions);
    }

    // Tear down the ring. Safe to call multiple times.
    void async_uring::stop() {
        if (m_ring) {
            uring_exit(m_ring);
            delete m_ring;
            m_ring = nullptr;
        }
    }

    void async_uring::cancel(net::fd_t /*fd*/) {
        // This is a no-op. For io_uring, this call does not cancel in-flight
        // operations. When the file descriptor is closed, the kernel cancels
        // them and they eventually complete with ECANCELED.
        // Immediate cancellation would require IORING_OP_ASYNC_CANCEL.
    }

} // namespace bee::async
diff --git a/bee/async/async_uring_linux.h b/bee/async/async_uring_linux.h
new file mode 100644
index 00000000..20a17706
--- /dev/null
+++ b/bee/async/async_uring_linux.h
@@ -0,0 +1,43 @@
#pragma once

#include
#include
#include
#include

#include
#include

// io_uring is defined internally in the .cpp; forward-declare the ring type here.
struct io_uring;

namespace bee::async {

    // Linux io_uring implementation of the async interface.
    // Construction may fail (old kernel, rlimit, seccomp); check valid() and
    // fall back to the epoll backend when it returns false.
    class async_uring : public async {
    public:
        async_uring();
        ~async_uring() override;

        bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) override;
        bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id) override;
        bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) override;
        bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id) override;
        bool submit_accept(net::fd_t listen_fd, uint64_t request_id) override;
        bool submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) override;
        bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) override;
        bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) override;
        bool submit_poll(net::fd_t fd, uint64_t request_id) override;
        int poll(const span& completions) override;
        int wait(const span& completions, int timeout) override;
        void stop() override;
        void cancel(net::fd_t fd) override;

        // True when ring initialisation succeeded; used by the factory in async.cpp.
        bool valid() const noexcept { return m_ring != nullptr; }

    private:
        io_uring* m_ring; // nullptr if ring initialisation failed

        // Drain completed CQEs into completions[]; returns the count written.
        int harvest_cqes(const span& completions) noexcept;
    };

} // namespace bee::async
diff --git a/bee/async/async_win.cpp b/bee/async/async_win.cpp
new file mode 100644
index 00000000..e579f6f2
--- /dev/null
+++ b/bee/async/async_win.cpp
@@ -0,0 +1,533 @@
// clang-format off
// WinSock2.h must be included before Windows.h to avoid winsock.h type redefinitions.
#include
// clang-format on
#include
#include
#include
#include
#include
#include
#include
#include

#include
#include
#include

// RtlNtStatusToDosError is in ntdll but not declared in standard MinGW headers.
+extern "C" ULONG WINAPI RtlNtStatusToDosError(NTSTATUS Status); + +// AcceptEx and ConnectEx are Winsock extension functions loaded dynamically via +// WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER). We declare only the GUIDs and +// function-pointer typedefs we need so that MSWSock.h is not required. +typedef BOOL(PASCAL* LPFN_ACCEPTEX)(SOCKET, SOCKET, PVOID, DWORD, DWORD, DWORD, LPDWORD, LPOVERLAPPED); +typedef BOOL(PASCAL* LPFN_CONNECTEX)(SOCKET, const sockaddr*, int, PVOID, DWORD, LPDWORD, LPOVERLAPPED); +static const GUID k_WSAID_ACCEPTEX = { 0xb5367df1, 0xcbac, 0x11cf, { 0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92 } }; +static const GUID k_WSAID_CONNECTEX = { 0x25a207b9, 0xddf3, 0x4660, { 0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e } }; + + +namespace bee::async { + + // sizeof(OVERLAPPED) is 32 on both x86 and x64. + static_assert(sizeof(OVERLAPPED) <= sizeof(async::overlapped_ext::overlapped), "overlapped_ext::overlapped buffer is too small"); + + static inline OVERLAPPED* as_ov(async::overlapped_ext* ext) { + return reinterpret_cast(ext->overlapped); + } + + async::async() + : m_iocp(nullptr) + , m_connectex(nullptr) + , m_acceptex(nullptr) { + m_iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0); + + // Load ConnectEx and AcceptEx at startup using a temporary socket. 
+ SOCKET tmp = WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED); + if (tmp != INVALID_SOCKET) { + DWORD bytes = 0; +GUID guid_connectex = WSAID_CONNECTEX; + WSAIoctl(tmp, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_connectex, sizeof(guid_connectex), &m_connectex, sizeof(m_connectex), &bytes, nullptr, nullptr); +GUID guid_acceptex = WSAID_ACCEPTEX; + WSAIoctl(tmp, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_acceptex, sizeof(guid_acceptex), &m_acceptex, sizeof(m_acceptex), &bytes, nullptr, nullptr); + closesocket(tmp); + } + } + + async::~async() { + stop(); + } + + void async::cancel(net::fd_t fd) { + CancelIoEx(reinterpret_cast(fd), nullptr); + } + + bool async::associate(net::fd_t fd) { + auto* h = reinterpret_cast(fd); + if (CreateIoCompletionPort(h, static_cast(m_iocp), 0, 0) != m_iocp) { + // ERROR_INVALID_PARAMETER means the handle is already associated with + // this same IOCP (e.g. the listen socket used across multiple accepts). + // Any other error is a real failure. + if (GetLastError() != ERROR_INVALID_PARAMETER) { + return false; + } + // Already associated: notification modes already set, nothing more to do. + return true; + } + // Newly associated: skip IOCP notification on synchronous completion and + // skip setting the handle event — all completions come through the IOCP. + SetFileCompletionNotificationModes(h, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | FILE_SKIP_SET_EVENT_ON_HANDLE); + return true; + } + + // Re-opens a file handle with FILE_FLAG_OVERLAPPED and associates it with + // the IOCP. Returns {overlapped file_handle, writable} on success, or + // {invalid file_handle, false} on failure. 
    std::pair async::associate_file(file_handle::value_type fd) {
        auto* h = static_cast(fd);
        bool writable = true;
        // Try read+write first; fall back to read-only if the original handle
        // (or the file's ACLs) don't permit writing.
        HANDLE ov_h = ReOpenFile(h, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, FILE_FLAG_OVERLAPPED);
        if (ov_h == INVALID_HANDLE_VALUE) {
            writable = false;
            ov_h = ReOpenFile(h, GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, FILE_FLAG_OVERLAPPED);
        }
        if (ov_h == INVALID_HANDLE_VALUE) return { file_handle {}, false };
        if (CreateIoCompletionPort(ov_h, static_cast(m_iocp), 0, 0) != m_iocp) {
            // ERROR_INVALID_PARAMETER == already associated; anything else is fatal.
            if (GetLastError() != ERROR_INVALID_PARAMETER) {
                CloseHandle(ov_h);
                return { file_handle {}, false };
            }
        }
        return { file_handle::from_native(ov_h), writable };
    }

    // Handle a synchronous completion (FILE_SKIP_COMPLETION_PORT_ON_SUCCESS):
    // build an io_completion from the overlapped_ext and either store it in
    // completions[] or append to m_sync_completions for later delivery.
    static io_completion make_sync_completion(std::unique_ptr ext, DWORD bytes) {
        io_completion c;
        c.request_id        = ext->request_id;
        c.bytes_transferred = static_cast(bytes);
        c.error_code        = 0;
        // Maps overlapped_ext::op_type (array index) to the public async_op enum.
        static constexpr async_op op_map[] = {
            async_op::read,
            async_op::readv,
            async_op::write,
            async_op::writev,
            async_op::accept,
            async_op::connect,
            async_op::file_read,
            async_op::file_write,
            async_op::fd_poll,
        };
        c.op = op_map[static_cast(ext->type)];
        // Zero bytes on a socket read/write means the peer closed; for
        // connect/file I/O/poll zero bytes is a normal success.
        if (bytes == 0 &&
            ext->type != async::overlapped_ext::op_connect &&
            ext->type != async::overlapped_ext::op_file_read &&
            ext->type != async::overlapped_ext::op_file_write &&
            ext->type != async::overlapped_ext::op_poll) {
            c.status = async_status::close;
        } else {
            c.status = async_status::success;
        }
        return c;
    }

    // Start an overlapped recv into buffer. The buffer must stay valid until
    // the completion is harvested.
    bool async::submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_read;
        ext->accept_sock = 0;

        WSABUF buf;
        buf.buf = static_cast(buffer);
        buf.len = static_cast(len);

        DWORD flags = 0;
        DWORD bytes = 0;
        int rc = WSARecv(static_cast(fd), &buf, 1, &bytes, &flags, reinterpret_cast(as_ov(ext.get())), nullptr);
        if (rc == 0) {
            // Synchronous completion — not posted to IOCP due to
            // FILE_SKIP_COMPLETION_PORT_ON_SUCCESS.
            m_sync_completions.push_back(make_sync_completion(std::move(ext), bytes));
            return true;
        }
        if (WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped scatter recv over bufs.
    bool async::submit_readv(net::fd_t fd, span bufs, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_readv;
        ext->accept_sock = 0;

        // iobuf is layout-compatible with WSABUF, so the span can be passed through.
        static_assert(sizeof(net::socket::iobuf) == sizeof(WSABUF));
        DWORD flags = 0;
        DWORD bytes = 0;
        int rc = WSARecv(static_cast(fd), reinterpret_cast(const_cast(bufs.data())), static_cast(bufs.size()), &bytes, &flags, reinterpret_cast(as_ov(ext.get())), nullptr);
        if (rc == 0) {
            m_sync_completions.push_back(make_sync_completion(std::move(ext), bytes));
            return true;
        }
        if (WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped send from buffer. The buffer must stay valid until
    // the completion is harvested.
    bool async::submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_write;
        ext->accept_sock = 0;

        WSABUF buf;
        buf.buf = const_cast(static_cast(buffer));
        buf.len = static_cast(len);

        DWORD bytes = 0;
        int rc = WSASend(static_cast(fd), &buf, 1, &bytes, 0, reinterpret_cast(as_ov(ext.get())), nullptr);
        if (rc == 0) {
            // Synchronous completion — not posted to IOCP due to
            // FILE_SKIP_COMPLETION_PORT_ON_SUCCESS.
            m_sync_completions.push_back(make_sync_completion(std::move(ext), bytes));
            return true;
        }
        if (WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped gather send over bufs.
    bool async::submit_writev(net::fd_t fd, span bufs, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_writev;
        ext->accept_sock = 0;

        static_assert(sizeof(net::socket::iobuf) == sizeof(WSABUF));
        DWORD bytes = 0;
        int rc = WSASend(static_cast(fd), reinterpret_cast(const_cast(bufs.data())), static_cast(bufs.size()), &bytes, 0, reinterpret_cast(as_ov(ext.get())), nullptr);
        if (rc == 0) {
            m_sync_completions.push_back(make_sync_completion(std::move(ext), bytes));
            return true;
        }
        if (WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped AcceptEx on listen_fd. A fresh accept socket is
    // created here; on success it is delivered through the completion.
    bool async::submit_accept(net::fd_t listen_fd, uint64_t request_id) {
        // Query the address family of the listen socket so the accept socket matches.
        net::endpoint ep;
        int af = AF_INET;
        if (net::socket::getsockname(listen_fd, ep)) {
            if (ep.get_family() == net::family::inet6) {
                af = AF_INET6;
            }
        }
        SOCKET accept_sock = WSASocketW(af, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED);
        if (accept_sock == INVALID_SOCKET) {
            return false;
        }
        // Associate the newly-created accept socket with the IOCP.
        if (!associate(static_cast(accept_sock))) {
            closesocket(accept_sock);
            return false;
        }

        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        memset(ext->addr_buf, 0, sizeof(ext->addr_buf));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_accept;
        ext->accept_sock = static_cast(accept_sock);
        ext->listen_sock = static_cast(static_cast(listen_fd));

        if (!m_acceptex) {
            closesocket(accept_sock);
            return false;
        }
        // AcceptEx output buffer: local addr slot + remote addr slot.
        // dwReceiveDataLength = 0 (no data prefix).
        auto fn_acceptex = reinterpret_cast(m_acceptex);
        DWORD bytes_received = 0;
        BOOL ok = fn_acceptex(static_cast(listen_fd), accept_sock, ext->addr_buf, 0, overlapped_ext::kAddrSlotSize, overlapped_ext::kAddrSlotSize, &bytes_received, reinterpret_cast(as_ov(ext.get())));
        if (!ok && WSAGetLastError() != WSA_IO_PENDING) {
            closesocket(accept_sock);
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped ConnectEx to ep.
    bool async::submit_connect(net::fd_t fd, const net::endpoint& ep, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_connect;
        ext->accept_sock = static_cast(static_cast(fd)); // reuse accept_sock to store connect fd

        // ConnectEx requires a pre-bound socket.
        // Bind to the appropriate wildcard address matching the endpoint family.
        if (ep.get_family() == net::family::inet6) {
            sockaddr_in6 bind_addr {};
            bind_addr.sin6_family = AF_INET6;
            bind(static_cast(fd), reinterpret_cast(&bind_addr), sizeof(bind_addr));
        } else {
            sockaddr_in bind_addr {};
            bind_addr.sin_family      = AF_INET;
            bind_addr.sin_addr.s_addr = INADDR_ANY;
            bind(static_cast(fd), reinterpret_cast(&bind_addr), sizeof(bind_addr));
        }

        if (!m_connectex) {
            return false;
        }
        auto fn_connectex = reinterpret_cast(m_connectex);
        BOOL ok = fn_connectex(static_cast(fd), ep.addr(), ep.addrlen(), nullptr, 0, nullptr, reinterpret_cast(as_ov(ext.get())));
        if (!ok && WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped positioned file read; the offset is carried in the
    // OVERLAPPED Offset/OffsetHigh pair.
    bool async::submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_file_read;
        ext->accept_sock = 0;

        auto* ov       = as_ov(ext.get());
        ov->Offset     = static_cast(offset & 0xFFFFFFFF);
        ov->OffsetHigh = static_cast(offset >> 32);

        BOOL ok = ReadFile(static_cast(fd), buffer, static_cast(len), nullptr, ov);
        if (!ok && GetLastError() != ERROR_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Start an overlapped positioned file write.
    bool async::submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id) {
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_file_write;
        ext->accept_sock = 0;

        auto* ov       = as_ov(ext.get());
        ov->Offset     = static_cast(offset & 0xFFFFFFFF);
        ov->OffsetHigh = static_cast(offset >> 32);

        BOOL ok = WriteFile(static_cast(fd), buffer, static_cast(len), nullptr, ov);
        if (!ok && GetLastError() != ERROR_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    bool async::submit_poll(net::fd_t fd, uint64_t request_id) {
        // Implement poll with a zero-byte WSARecv: it consumes no data, and the
        // IOCP receives a completion notification when fd becomes readable.
        auto ext = std::make_unique();
        memset(ext->overlapped, 0, sizeof(ext->overlapped));
        ext->request_id  = request_id;
        ext->type        = overlapped_ext::op_poll;
        ext->accept_sock = 0;

        WSABUF buf;
        buf.buf = nullptr;
        buf.len = 0;

        DWORD flags = 0;
        DWORD bytes = 0;
        int rc = WSARecv(static_cast(fd), &buf, 1, &bytes, &flags, reinterpret_cast(as_ov(ext.get())), nullptr);
        if (rc == 0) {
            m_sync_completions.push_back(make_sync_completion(std::move(ext), bytes));
            return true;
        }
        if (WSAGetLastError() != WSA_IO_PENDING) {
            return false;
        }
        ext.release(); // ownership transferred to IOCP
        return true;
    }

    // Translate one dequeued OVERLAPPED_ENTRY into an io_completion, taking
    // back ownership of (and freeing) the overlapped_ext.
    static io_completion make_completion(OVERLAPPED_ENTRY& entry) {
        std::unique_ptr ext(
            reinterpret_cast(entry.lpOverlapped)
        );
        auto* ov = reinterpret_cast(ext->overlapped);
        io_completion c;
        c.request_id        = ext->request_id;
        c.bytes_transferred = static_cast(entry.dwNumberOfBytesTransferred);
        c.error_code        = 0;
        // Maps overlapped_ext::op_type (array index) to the public async_op enum.
        static constexpr async_op op_map[] = {
            async_op::read,
            async_op::readv,
            async_op::write,
            async_op::writev,
            async_op::accept,
            async_op::connect,
            async_op::file_read,
            async_op::file_write,
            async_op::fd_poll,
        };
        c.op = op_map[static_cast(ext->type)];

        // OVERLAPPED.Internal holds the NTSTATUS of the completed I/O.
        // A non-zero value means the operation failed.
        DWORD err = 0;
        if (ov->Internal != 0) {
            // Convert NTSTATUS to Win32 error code.
            err = static_cast(ov->Internal);
            if ((err & 0xC0000000) == 0xC0000000) {
                // Try to get a proper Win32 error via RtlNtStatusToDosError.
                // For simplicity, mask off the NTSTATUS facility bits;
                // many NTSTATUS codes map directly to Win32 errors.
                err = RtlNtStatusToDosError(static_cast(ov->Internal));
            }
        }

        if (ext->type == async::overlapped_ext::op_accept) {
            if (err == 0) {
                SOCKET as = static_cast(ext->accept_sock);
                SOCKET ls = static_cast(ext->listen_sock);
                // Inherit listen socket properties so getsockname/getpeername/shutdown work.
                if (setsockopt(as, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&ls), sizeof(ls)) != 0) {
                    closesocket(as);
                    c.status     = async_status::error;
                    c.error_code = static_cast(WSAGetLastError());
                    return c;
                }
                // Set non-blocking to match bee.socket.accept behaviour.
                unsigned long nonblock = 1;
                if (ioctlsocket(as, FIONBIO, &nonblock) != 0) {
                    closesocket(as);
                    c.status     = async_status::error;
                    c.error_code = static_cast(WSAGetLastError());
                    return c;
                }
                c.status            = async_status::success;
                c.bytes_transferred = ext->accept_sock; // new fd
            } else {
                closesocket(static_cast(ext->accept_sock));
                c.status     = async_status::error;
                c.error_code = static_cast(err);
            }
        } else if (ext->type == async::overlapped_ext::op_connect) {
            if (err == 0) {
                // Update the socket context so getsockname/shutdown/etc. work correctly.
                if (setsockopt(static_cast(ext->accept_sock), SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, nullptr, 0) != 0) {
                    c.status     = async_status::error;
                    c.error_code = static_cast(WSAGetLastError());
                    return c;
                }
                c.status = async_status::success;
            } else {
                c.status     = async_status::error;
                c.error_code = static_cast(err);
            }
        } else if (err != 0) {
            c.status     = async_status::error;
            c.error_code = static_cast(err);
        } else if (entry.dwNumberOfBytesTransferred == 0 &&
                   ext->type != async::overlapped_ext::op_file_read &&
                   ext->type != async::overlapped_ext::op_file_write &&
                   ext->type != async::overlapped_ext::op_poll) {
            // Zero bytes on a socket read/write means the peer closed.
            c.status = async_status::close;
        } else {
            c.status = async_status::success;
        }

        return c;
    }

    // Non-blocking harvest: drain stored synchronous completions first, then
    // dequeue whatever is already waiting on the IOCP.
    int async::poll(const span& completions) {
        int count = 0;
        // Drain synchronous completions first.
        while (!m_sync_completions.empty() && count < static_cast(completions.size())) {
            completions[count++] = m_sync_completions.front();
            m_sync_completions.pop_front();
        }
        if (count >= static_cast(completions.size())) return count;

        static_assert(sizeof(OVERLAPPED_ENTRY) > 0);
        const ULONG max = static_cast(completions.size()) - static_cast(count);
        dynarray entries(max);
        ULONG n = 0;
        BOOL ok = GetQueuedCompletionStatusEx(static_cast(m_iocp), entries.data(), max, &n, 0, FALSE);
        if (ok && n > 0) {
            for (ULONG i = 0; i < n; ++i) {
                completions[count++] = make_completion(entries[i]);
            }
        }
        return count;
    }

    // Blocking harvest: timeout in milliseconds; negative means wait forever.
    int async::wait(const span& completions, int timeout) {
        int count = 0;
        // Drain synchronous completions first — if any exist, return immediately
        // without blocking.
        while (!m_sync_completions.empty() && count < static_cast(completions.size())) {
            completions[count++] = m_sync_completions.front();
            m_sync_completions.pop_front();
        }
        if (count > 0) return count;

        const ULONG max = static_cast(completions.size());
        dynarray entries(max);
        DWORD ms = (timeout < 0) ? INFINITE : static_cast(timeout);
        ULONG n  = 0;
        BOOL ok  = GetQueuedCompletionStatusEx(static_cast(m_iocp), entries.data(), max, &n, ms, FALSE);
        if (!ok || n == 0) return 0;
        for (ULONG i = 0; i < n; ++i) {
            completions[i] = make_completion(entries[i]);
        }
        return static_cast(n);
    }

    // Close the IOCP. Safe to call multiple times.
    void async::stop() {
        if (!m_iocp) return;

        // Drain all completion notifications so every overlapped_ext is freed.
        // After CancelIoEx the cancelled operations will complete with
        // ERROR_OPERATION_ABORTED; we just need to delete each ext.
        static constexpr ULONG kDrainBatch = 64;
        OVERLAPPED_ENTRY entries[kDrainBatch];
        for (;;) {
            ULONG count = 0;
            BOOL ok = GetQueuedCompletionStatusEx(static_cast(m_iocp), entries, kDrainBatch, &count, 0, FALSE);
            if (!ok || count == 0) break;
            for (ULONG i = 0; i < count; ++i) {
                std::unique_ptr ext(
                    reinterpret_cast(entries[i].lpOverlapped)
                );
                // Accept sockets that never completed must be closed explicitly.
                if (ext->type == async::overlapped_ext::op_accept) {
                    closesocket(static_cast(ext->accept_sock));
                }
            }
        }

        CloseHandle(static_cast(m_iocp));
        m_iocp = nullptr;
    }

} // namespace bee::async
diff --git a/bee/async/async_win.h b/bee/async/async_win.h
new file mode 100644
index 00000000..be31e8da
--- /dev/null
+++ b/bee/async/async_win.h
@@ -0,0 +1,78 @@
#pragma once

#if defined(_WIN32)
// clang-format off
# include
// clang-format on
# include
#endif

#include
#include
#include
#include
#include

#include
#include
#include
#include

namespace bee::net {
    struct endpoint;
}

namespace bee::async {

    // Windows IOCP implementation of the async interface.
    class async {
    public:
        async();
        ~async();

        bool submit_read(net::fd_t fd, void* buffer, size_t len, uint64_t request_id);
        bool submit_readv(net::fd_t fd, span bufs, uint64_t request_id);
        bool submit_write(net::fd_t fd, const void* buffer, size_t len, uint64_t request_id);
        bool submit_writev(net::fd_t fd, span bufs, uint64_t request_id);
        bool submit_accept(net::fd_t listen_fd, uint64_t request_id);
        bool submit_connect(net::fd_t
 fd, const net::endpoint& ep, uint64_t request_id);
        bool submit_file_read(file_handle::value_type fd, void* buffer, size_t len, int64_t offset, uint64_t request_id);
        bool submit_file_write(file_handle::value_type fd, const void* buffer, size_t len, int64_t offset, uint64_t request_id);
        bool submit_poll(net::fd_t fd, uint64_t request_id);
        int poll(const span& completions);
        int wait(const span& completions, int timeout);
        void stop();
        bool associate(net::fd_t fd);
        std::pair associate_file(file_handle::value_type fd);
        void cancel(net::fd_t fd);

        // Per-operation context handed to the OS alongside the OVERLAPPED.
        // Allocated per submit_*; ownership passes to the IOCP while pending
        // and back to make_completion when the entry is dequeued.
        struct overlapped_ext {
            // OVERLAPPED must be first so we can cast between the two.
            unsigned char overlapped[32]; // >= sizeof(OVERLAPPED) on x86 and x64 (static_assert in the .cpp)
            uint64_t request_id;
            uintptr_t accept_sock; // used only for op_accept (and reused to hold the fd for op_connect)
            uintptr_t listen_sock; // used only for op_accept (for SO_UPDATE_ACCEPT_CONTEXT)
            enum op_type : uint8_t {
                op_read,
                op_readv,
                op_write,
                op_writev,
                op_accept,
                op_connect,
                op_file_read,
                op_file_write,
                op_poll,
            } type;
            // Output buffer for AcceptEx: (sizeof(sockaddr_storage)+16) * 2 bytes
            // sizeof(sockaddr_storage)==128, so 144*2 = 288 bytes
            static constexpr DWORD kAddrSlotSize = sizeof(sockaddr_storage) + 16;
            char addr_buf[kAddrSlotSize * 2];
        };

    private:
        void* m_iocp;      // HANDLE
        void* m_connectex; // LPFN_CONNECTEX
        void* m_acceptex;  // LPFN_ACCEPTEX
        // Completions that finished synchronously (skipped the IOCP) and are
        // delivered on the next poll/wait.
        std::deque m_sync_completions;
    };

} // namespace bee::async
diff --git a/bee/async/ring_buf.h b/bee/async/ring_buf.h
new file mode 100644
index 00000000..54f52318
--- /dev/null
+++ b/bee/async/ring_buf.h
@@ -0,0 +1,121 @@
#pragma once

#include
#include
#include
#include

namespace bee::async {

    // Per-stream receive ring buffer.
    //
    // cap is always a power of two so that index masking (& (cap-1)) works.
    // head and tail are monotonically increasing absolute offsets; the actual
    // position in data[] is obtained via (offset & (cap-1)).
    //
    // Invariants:
    //   0 <= size() <= cap
    //   write_ptr() points into the free region [tail, head+cap)
    //
    // Lifetime: held as a Lua userdata object; GC handles deallocation.
    // Thread safety: none -- accessed from a single Lua thread.
    struct ring_buf {
        char* data  = nullptr;
        size_t cap  = 0; // capacity, always a power of two
        size_t head = 0; // consumer cursor (absolute, never wraps)
        size_t tail = 0; // producer cursor (absolute, never wraps)

        // --------------- capacity / state queries ---------------

        size_t size() const noexcept { return tail - head; }
        size_t free_cap() const noexcept { return cap - size(); }
        bool empty() const noexcept { return head == tail; }

        // --------------- producer side (used by submit_stream_read) ---------------

        // Pointer to the start of the contiguous free region.
        char* write_ptr() noexcept {
            return data + (tail & (cap - 1));
        }

        // Length of the contiguous free region. May be less than free_cap() when
        // the free space wraps around the end of the buffer.
        size_t write_len() const noexcept {
            if (free_cap() == 0) return 0;
            size_t wrap_end = cap - (tail & (cap - 1));
            return (std::min)(wrap_end, free_cap());
        }

        // Advance the producer cursor after a successful read completion.
        void commit(size_t n) noexcept {
            assert(n <= free_cap());
            tail += n;
        }

        // --------------- consumer side (used by read) ---------------

        // Search for the first occurrence of the byte sequence [sep, sep+seplen) in
        // the buffered data. Returns the number of bytes from head up to and
        // including the last byte of the found sequence, or 0 if not found.
        // seplen must be >= 1.
        size_t find(const char* sep, size_t seplen) const noexcept {
            size_t n = size();
            if (n < seplen) return 0;
            size_t limit = n - seplen + 1;
            for (size_t i = 0; i < limit; ++i) {
                bool match = true;
                for (size_t j = 0; j < seplen; ++j) {
                    if (data[(head + i + j) & (cap - 1)] != sep[j]) {
                        match = false;
                        break;
                    }
                }
                if (match) return i + seplen;
            }
            return 0;
        }

        // Copy exactly n bytes from the ring into dst and advance head.
        // Returns false (without modifying head) if fewer than n bytes are available.
        bool consume(char* dst, size_t n) noexcept {
            if (size() < n) return false;
            size_t idx   = head & (cap - 1);
            size_t first = (std::min)(n, cap - idx);
            memcpy(dst, data + idx, first);
            if (first < n) {
                // Wrapped: the remainder sits at the start of the buffer.
                memcpy(dst + first, data, n - first);
            }
            head += n;
            return true;
        }

        // --------------- lifecycle ---------------

        // Round sz up to the nearest power of two (minimum 16).
        static size_t round_up_pow2(size_t sz) noexcept {
            if (sz < 16) sz = 16;
            sz--;
            sz |= sz >> 1;
            sz |= sz >> 2;
            sz |= sz >> 4;
            sz |= sz >> 8;
            sz |= sz >> 16;
            if constexpr (sizeof(size_t) > 4) {
                sz |= sz >> 32;
            }
            return sz + 1;
        }

        explicit ring_buf(size_t bufsize) {
            cap  = round_up_pow2(bufsize);
            data = new char[cap];
        }

        ring_buf() = default;

        ~ring_buf() noexcept {
            delete[] data;
        }
    };

} // namespace bee::async
diff --git a/bee/async/write_buf.h b/bee/async/write_buf.h
new file mode 100644
index 00000000..5dd58eef
--- /dev/null
+++ b/bee/async/write_buf.h
@@ -0,0 +1,45 @@
#pragma once

#include
#include

#include
#include
#include

namespace bee::async {

    // Per-stream write buffer.
    //
    // Maintains a queue of pending string entries to be sent over a socket.
    // The C layer submits all queued entries at once via submit_write using
    // submit_writev with multiple iovecs, handling partial writes transparently
    // (auto-drain).
When the queue drains to empty a single Lua-visible + // completion is produced. + // + // hwm (high-water mark): write() returns true when buffered >= hwm, + // signalling the Lua caller to back-pressure (block until stream_on_write). + // + // Lifetime: held as a Lua userdata object; GC handles deallocation. + // Thread safety: none -- accessed from a single Lua thread. + struct write_buf { + struct entry { + const char* data; // pointer into Lua string (kept alive by str_ref) + size_t len; + size_t offset; // bytes already sent (partial write progress) + int str_ref; // luaL_ref keeping the Lua string alive + }; + + std::deque q; + size_t buffered = 0; // total queued bytes + size_t hwm = 0; // high-water mark threshold + bool in_flight = false; // true while a writev is outstanding + // Fields valid only while in_flight == true: + net::fd_t fd = net::fd_t {}; // socket being written to + uint64_t lua_reqid = 0; // Lua-assigned reqid for final completion + // iov snapshot for the current in-flight writev (entries may span multiple q items). + // Built by wb_submit_all; must remain valid until the completion callback fires. 
+ dynarray iov_cache; + }; + +} // namespace bee::async diff --git a/bee/net/socket.cpp b/bee/net/socket.cpp index 5324b350..01721b16 100644 --- a/bee/net/socket.cpp +++ b/bee/net/socket.cpp @@ -355,6 +355,46 @@ namespace bee::net::socket { return recv_status::success; } + recv_status recvv(fd_t s, int& rc, span bufs) noexcept { + if (bufs.empty()) { + rc = 0; + return recv_status::success; + } + if (bufs.size() == 1) { +#if defined(_WIN32) + return recv(s, rc, bufs[0].buf, (int)bufs[0].len); +#else + return recv(s, rc, (char*)bufs[0].iov_base, (int)bufs[0].iov_len); +#endif + } +#if defined(_WIN32) + static_assert(sizeof(iobuf) == sizeof(WSABUF)); + DWORD received = 0; + DWORD flags = 0; + const int ok = ::WSARecv((SOCKET)s, reinterpret_cast(bufs.data()), (DWORD)bufs.size(), &received, &flags, NULL, NULL); + if (ok == SOCKET_ERROR) { + const int err = ::WSAGetLastError(); + if (err == WSAEINPROGRESS || err == WSAEWOULDBLOCK) return recv_status::wait; + return recv_status::failed; + } + rc = (int)received; + return rc == 0 ? recv_status::close : recv_status::success; +#else + static_assert(sizeof(iobuf) == sizeof(struct iovec)); + struct msghdr msg = {}; + msg.msg_iov = reinterpret_cast(bufs.data()); + msg.msg_iovlen = (int)bufs.size(); + rc = (int)::recvmsg(s, &msg, 0); + if (rc == 0) { + return recv_status::close; + } + if (rc < 0) { + return wait_finish() ? 
recv_status::wait : recv_status::failed; + } + return recv_status::success; +#endif + } + status send(fd_t s, int& rc, const char* buf, int len) noexcept { int flags = 0; #ifdef MSG_NOSIGNAL diff --git a/bee/net/socket.h b/bee/net/socket.h index fed5528d..313eb0e2 100644 --- a/bee/net/socket.h +++ b/bee/net/socket.h @@ -84,6 +84,7 @@ namespace bee::net::socket { status connect(fd_t s, const endpoint& ep) noexcept; status accept(fd_t s, fd_t& newfd, fd_flags flags = fd_flags::nonblock) noexcept; recv_status recv(fd_t s, int& rc, char* buf, int len) noexcept; + recv_status recvv(fd_t s, int& rc, span bufs) noexcept; status send(fd_t s, int& rc, const char* buf, int len) noexcept; status sendv(fd_t s, int& rc, span bufs) noexcept; status recvfrom(fd_t s, int& rc, endpoint& ep, char* buf, int len) noexcept; diff --git a/benchmark/bench.lua b/benchmark/bench.lua new file mode 100644 index 00000000..f11cb0b5 --- /dev/null +++ b/benchmark/bench.lua @@ -0,0 +1,194 @@ +-- bench.lua: 对比 epoll 后端(net_epoll.lua)和 async 后端(net_async.lua)的性能 +-- +-- 用法(从 benchmark/ 目录运行): +-- lua bench.lua [并发数] [每连接消息数] [消息大小] +-- +-- 始终同时运行 epoll/async × rtt/pipeline 四种组合并输出对比。 + +local time = require "bee.time" + +local conns = tonumber(arg and arg[1]) or 10 +local msgs = tonumber(arg and arg[2]) or 1000 +local msg_size = tonumber(arg and arg[3]) or 4096 + +local function load_backend(name) + package.loaded["net_epoll"] = nil + package.loaded["net_async"] = nil + package.loaded["ltask"] = nil + + local net = require(name == "epoll" and "net_epoll" or "net_async") + local ltask = require "ltask" + + local socket = require "bee.socket" + local probe = assert(socket.create "tcp") + assert(probe:bind("127.0.0.1", 0)) + local _, port = probe:info("socket"):value() + probe:close() + + return net, ltask, port +end + +local function run_loop(net, ltask, finish_token) + local exit_loop = false + net.fork(function () + ltask.wait(finish_token) + exit_loop = true + end) + while not exit_loop do + 
net.schedule() -- 耗尽所有可运行任务 + net.wait(0) -- I/O 完成事件触发新 wakeup + -- 若 I/O 产生了新任务,继续调度;否则阻塞等待 + if not net.schedule() and not exit_loop then + net.wait(1) + end + end + net.wait(0) + net.schedule() +end + +-- 公共框架:启动 echo 服务端,并发运行 client_fn,返回统计 +local function run_bench(backend_name, conns_count, client_fn) + local net, ltask, port = load_backend(backend_name) + local done = 0 + local total_msg = 0 + local total_lat = 0 + local errors = 0 + local token = {} + + local function finish(ok_msgs, lat_sum) + total_msg = total_msg + ok_msgs + total_lat = total_lat + (lat_sum or 0) + done = done + 1 + if done >= conns_count then ltask.wakeup(token) end + end + + net.fork(function () + local server = assert(net.listen("tcp", "127.0.0.1", port)) + while true do + local cli = server:accept() + if not cli then break end + net.fork(function () + while true do + local data = cli:recv() + if not data or data == "" then break end + net.fork(function () cli:send(data) end) + end + cli:close() + end) + end + end) + + for _ = 1, conns_count do + net.fork(function () + local fd = net.connect("tcp", "127.0.0.1", port) + if not fd then + errors = errors + 1 + finish(0) + return + end + local ok, lat = client_fn(net, fd) + errors = errors + (ok < 0 and -ok or 0) + fd:close() + finish(ok > 0 and ok or 0, lat) + end) + end + + local t0 = time.monotonic() + run_loop(net, ltask, token) + local elapsed = time.monotonic() - t0 + + return { + total_msg = total_msg, + bytes = total_msg * msg_size * 2, + elapsed = elapsed, + errors = errors, + lat_avg = total_msg > 0 and (total_lat / total_msg * 1000) or 0, + } +end + +local function run_rtt(backend_name, conns_count, msgs_count, payload_size) + local payload = string.rep("x", payload_size) + return run_bench(backend_name, conns_count, function (_, fd) + local ok, lat_sum = 0, 0 + for _ = 1, msgs_count do + local t0 = time.monotonic() + if not fd:send(payload) then return -1, 0 end + local echo = fd:recv(payload_size) + if echo and #echo == 
payload_size then + ok = ok + 1 + lat_sum = lat_sum + (time.monotonic() - t0) + else + return -1, 0 + end + end + return ok, lat_sum + end) +end + +local function run_pipeline(backend_name, conns_count, msgs_count, payload_size) + local payload = string.rep("x", payload_size) + return run_bench(backend_name, conns_count, function (net, fd) + net.fork(function () + for _ = 1, msgs_count do fd:send(payload) end + end) + local received, expected = 0, msgs_count * payload_size + while received < expected do + local r = fd:recv() + if not r then return -1, 0 end + received = received + #r + end + return msgs_count, nil + end) +end + +local function print_result(label, r) + local s = math.max(r.elapsed / 1000, 0.001) + local lat_str = r.lat_avg > 0 and string.format("%7.1fµs", r.lat_avg) or " n/a " + print(string.format( + "%-16s | msgs=%7d | time=%7.1fms | %9.0f msg/s | %6.2f MB/s | lat=%s | err=%d", + label, r.total_msg, r.elapsed, + r.total_msg / s, r.bytes / s / 1024 / 1024, + lat_str, r.errors + )) +end + +local backends = { "epoll", "async" } +local modes = { "rtt", "pipeline" } +local runners = { rtt = run_rtt, pipeline = run_pipeline } + +print(string.format( + "=== bee.lua 后端性能对比 | 并发=%d | 每连接消息=%d | 消息大小=%d 字节 | 总消息=%d ===", + conns, msgs, msg_size, conns * msgs +)) +print(string.format("%-16s | %-7s | %-10s | %-14s | %-9s | %-10s | %s", + "后端/模式", "总消息", "耗时", "吞吐(msg/s)", "吞吐(MB/s)", "平均RTT", "错误")) +print(string.rep("-", 100)) + +local results = {} +for _, mode in ipairs(modes) do + results[mode] = {} + for _, b in ipairs(backends) do + local label = b.."/"..mode + local ok, r = pcall(runners[mode], b, conns, msgs, msg_size) + if ok then + results[mode][b] = r + print_result(label, r) + else + print(string.format("%-16s | 运行失败: %s", label, tostring(r))) + end + end +end + +print(string.rep("-", 100)) +for _, mode in ipairs(modes) do + local e, a = results[mode]["epoll"], results[mode]["async"] + if e and a then + local es = math.max(e.elapsed / 1000, 0.001) 
+ local as_ = math.max(a.elapsed / 1000, 0.001) + local tput = (a.total_msg / as_) / (e.total_msg / es) + local lat = (e.lat_avg > 0 and a.lat_avg > 0) + and string.format(" | 延迟 %.2fx", e.lat_avg / a.lat_avg) or "" + print(string.format("[%-8s] async 相对 epoll: 吞吐 %.2fx%s (>1 表示 async 更好)", + mode, tput, lat)) + end +end diff --git a/benchmark/ltask.lua b/benchmark/ltask.lua new file mode 100644 index 00000000..3dd37c8c --- /dev/null +++ b/benchmark/ltask.lua @@ -0,0 +1,146 @@ +local task = {} + +local MESSAGE_OK = 0 +local MESSAGE_ERROR = 1 + +local running_thread +local wakeup_queue = {} +local message_queue = {} +local thread_address = {} +local waiting = {} +local handler = {} + +local function wakeup_thread(...) + wakeup_queue[#wakeup_queue + 1] = { ... } +end + +local function handle_message(from, command, ...) + local s = handler[command] + if not s then + error("Unknown message : " .. command) + return + end + wakeup_thread(from, MESSAGE_OK, s(...)) + thread_address[running_thread] = nil +end + +local function do_wakeup(co, ...) + running_thread = co + local ok, errobj = coroutine.resume(co, ...) + running_thread = nil + if ok then + return errobj + else + local from = thread_address[co] + thread_address[co] = nil + errobj = debug.traceback(co, errobj) + if from == nil then + print("Error:", tostring(errobj)) + else + wakeup_thread(from, MESSAGE_ERROR, errobj) + end + coroutine.close(co) + end +end + +local function do_message(from, ...) + local co = coroutine.create(handle_message) + thread_address[co] = from + do_wakeup(co, from, ...) +end + +local function response(type, ...) + if type == MESSAGE_OK then + return ... + else -- type == MESSAGE_ERROR + error(...) + end +end + +local function send_message(from, ...) + message_queue[#message_queue + 1] = { from, ... 
} +end + +local function no_response_() + while true do + local type, errobj = coroutine.yield() + if type == MESSAGE_ERROR then + print("Error:", tostring(errobj)) + end + end +end +local no_response_handler = coroutine.create(no_response_) +coroutine.resume(no_response_handler) + +function task.wait(token) + token = token + waiting[token] = running_thread + return response(coroutine.yield()) +end + +function task.wakeup(token, ...) + local co = waiting[token] + if co then + wakeup_thread(co, MESSAGE_OK, ...) + waiting[token] = nil + return true + end +end + +function task.interrupt(token, errobj) + local co = waiting[token] + if co then + errobj = debug.traceback(errobj) + wakeup_thread(co, MESSAGE_ERROR, errobj) + waiting[token] = nil + return true + end +end + +function task.fork(func, ...) + local co = coroutine.create(func) + wakeup_thread(co, ...) +end + +function task.yield() + wakeup_thread(running_thread) + coroutine.yield() +end + +function task.call(command, ...) + send_message(running_thread, command, ...) + return response(coroutine.yield()) +end + +function task.send(command, ...) + send_message(no_response_handler, command, ...) 
+end + +function task.dispatch(h) + for k, v in pairs(h) do + handler[k] = v + end +end + +function task.schedule() + local did_work = false + -- 耗尽所有可运行任务,避免每条消息触发一次 I/O poll + while #message_queue ~= 0 do + local s = table.remove(message_queue, 1) + do_message(table.unpack(s)) + did_work = true + end + while #wakeup_queue ~= 0 do + local s = table.remove(wakeup_queue, 1) + do_wakeup(table.unpack(s)) + did_work = true + -- wakeup 可能产生新 message,立即处理 + while #message_queue ~= 0 do + local s2 = table.remove(message_queue, 1) + do_message(table.unpack(s2)) + end + end + return did_work or nil +end + +return task diff --git a/benchmark/net_async.lua b/benchmark/net_async.lua new file mode 100644 index 00000000..0741a22d --- /dev/null +++ b/benchmark/net_async.lua @@ -0,0 +1,347 @@ +local ltask = require "ltask" +local socket = require "bee.socket" +local async = require "bee.async" + +local asfd = async.create(512) + +local SUCCESS = async.SUCCESS +local OP_READ = async.OP_READ +local OP_READV = async.OP_READV +local OP_WRITEV = async.OP_WRITEV +local OP_ACCEPT = async.OP_ACCEPT +local OP_CONNECT = async.OP_CONNECT + +local kRbSize = 64 * 1024 -- per-stream read ring buffer size +local kWbSize = 64 * 1024 -- write queue high-water-mark (bytes) + +local status = {} +local handle = {} + +local function create_handle(fd) + local h = handle[fd] + if h then + return h + end + h = #handle + 1 + handle[h] = fd + handle[fd] = h + return h +end + +local function close_stream(s) + if s.closed then + return + end + s.closed = true + if s.wait_read then + for i, token in ipairs(s.wait_read) do + ltask.wakeup(token) + s.wait_read[i] = nil + end + end + if s.wb then + for _, token in ipairs(s.wait_write) do + ltask.wakeup(token) + end + s.wait_write = {} + s.wb:close() + s.wb = nil + end + if s.wait_close then + for _, token in ipairs(s.wait_close) do + ltask.wakeup(token) + end + s.wait_close = nil + end + asfd:cancel(s.fd) + s.fd:close() +end + +local function 
stream_submit_read(s) + if s.closed or s.rb_in_flight then + return + end + if asfd:submit_read(s.rb, s.fd, s) then + s.rb_in_flight = true + end +end + +local function stream_on_read(s, bytes) + s.rb_in_flight = false + if not bytes or bytes == 0 then + close_stream(s) + return + end + while s.wait_read and #s.wait_read > 0 do + local token = s.wait_read[1] + local n = token[1] + local data = s.rb:read(n) + if not data then + break + end + table.remove(s.wait_read, 1) + ltask.wakeup(token, data) + end + if not s.closed then + stream_submit_read(s) + end +end + +-- Submit the writebuf to the async layer if not already in-flight. +local function stream_submit_write(s) + if s.closed or not s.wb or s.wb_in_flight then + return + end + if s.wb:buffered() == 0 then + return + end + if not asfd:submit_write(s.wb, s.fd, s) then + close_stream(s) + return + end + s.wb_in_flight = true +end + +-- Called when a submit_write completion arrives (queue fully drained). +local function stream_on_write(s) + s.wb_in_flight = false + -- Wake up all blocked writers. + local ww = s.wait_write + s.wait_write = {} + for _, token in ipairs(ww) do + ltask.wakeup(token) + end + -- If more data was enqueued while draining, resubmit. + if s.wb and s.wb:buffered() > 0 then + stream_submit_write(s) + end +end + +local function create_stream(newfd) + local s = { + fd = newfd, + rb = async.readbuf(kRbSize), + wb = async.writebuf(kWbSize), + wait_read = {}, + wait_write = {}, + rb_in_flight = false, + wb_in_flight = false, + closed = false, + } + status[newfd] = s + newfd:option("nodelay", 1) + stream_submit_read(s) + return create_handle(newfd) +end + +local S = {} + +function S.listen(protocol, ...) + local fd, err = socket.create(protocol) + if not fd then + return nil, err + end + asfd:associate(fd) + local ok, err = fd:bind(...) 
+    if not ok then
+        -- bind failed: close the socket before reporting, otherwise the fd leaks
+        fd:close()
+        return nil, err
+    end
+    ok, err = fd:listen()
+    if not ok then
+        -- listen failed: same — release the fd (S.connect already does this on failure)
+        fd:close()
+        return nil, err
+    end
+    local s = {
+        fd = fd,
+        closed = false,
+        is_listener = true,
+    }
+    status[fd] = s
+    return create_handle(fd)
+end
+
+function S.connect(protocol, host, port)
+    if host and port then
+        local ep = socket.endpoint("hostname", host, port)
+        if not ep then
+            return nil, string.format("resolve hostname failed: %s:%d", host, port)
+        end
+        -- Pick the protocol family matching the resolved address.
+        local _, _, family = ep:value()
+        if family == "inet6" then
+            if protocol == "tcp" then
+                protocol = "tcp6"
+            elseif protocol == "udp" then
+                protocol = "udp6"
+            end
+        end
+    end
+    local fd, err = socket.create(protocol)
+    if not fd then
+        return nil, err
+    end
+    asfd:associate(fd)
+    local wait_token = {}
+    local ok, cerr = asfd:submit_connect(fd, host, port, wait_token)
+    if not ok then
+        fd:close()
+        return nil, cerr
+    end
+    local result = ltask.wait(wait_token)
+    if not result then
+        fd:close()
+        return nil, "connect failed"
+    end
+    return create_stream(fd)
+end
+
+function S.accept(h)
+    local fd = assert(handle[h], "Invalid fd.")
+    local s = status[fd]
+    assert(s.is_listener, "Not a listener.")
+    local wait_token = {}
+    -- NOTE(review): submit_accept's status is ignored here; S.connect checks
+    -- submit_connect's return — confirm the binding returns a bool and check it.
+    asfd:submit_accept(fd, wait_token)
+    local newfd = ltask.wait(wait_token)
+    if not newfd then
+        return nil, "accept failed"
+    end
+    local ok, err = newfd:status()
+    if not ok then
+        newfd:close()
+        return nil, err
+    end
+    return create_stream(newfd)
+end
+
+function S.send(h, data)
+    local fd = assert(handle[h], "Invalid fd.")
+    local s = status[fd]
+    if s.closed or not s.wb then
+        return
+    end
+    if data == "" then
+        return
+    end
+    -- Enqueue; wb:write returns true when buffered >= hwm (backpressure).
+    local full = s.wb:write(data)
+    -- Kick off a submit if nothing is currently in-flight.
+    if not s.wb_in_flight then
+        stream_submit_write(s)
+    end
+    -- Block if we pushed the buffer over hwm.
+ if full then + local token = {} + s.wait_write[#s.wait_write+1] = token + ltask.wait(token) + end + return true +end + +function S.recv(h, n) + local fd = assert(handle[h], "Invalid fd.") + local s = status[fd] + if not s.rb then + error "Read not allowed." + return + end + local data = s.rb:read(n) + if data then + stream_submit_read(s) + return data + end + if s.closed then + return + end + local token = { n } + s.wait_read[#s.wait_read+1] = token + stream_submit_read(s) + return ltask.wait(token) +end + +function S.close(h) + local fd = handle[h] + if fd then + local s = status[fd] + close_stream(s) + handle[h] = nil + handle[fd] = nil + status[fd] = nil + end +end + +function S.is_closed(h) + local fd = handle[h] + if fd then + return status[fd].closed + end + return true +end + +local fd_mt = {} +fd_mt.__index = fd_mt + +function fd_mt:accept(...) + local fd, err = ltask.call("accept", self.fd, ...) + if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +function fd_mt:send(...) + return ltask.call("send", self.fd, ...) +end + +function fd_mt:recv(...) + return ltask.call("recv", self.fd, ...) +end + +function fd_mt:close(...) + return ltask.call("close", self.fd, ...) +end + +function fd_mt:is_closed(...) + return ltask.call("is_closed", self.fd, ...) +end + +local net = {} + +function net.wait(timeout) + for op, udata, st, data in asfd:wait(timeout) do + if op == OP_READ or op == OP_READV then + stream_on_read(udata, st == SUCCESS and data or nil) + elseif op == OP_WRITEV then + if st == SUCCESS then + stream_on_write(udata) + else + close_stream(udata) + end + elseif op == OP_ACCEPT then + ltask.wakeup(udata, st == SUCCESS and data or nil) + elseif op == OP_CONNECT then + ltask.wakeup(udata, st == SUCCESS and true or nil) + end + end +end + +function net.listen(...) + local fd, err = ltask.call("listen", ...) 
+ if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +function net.connect(...) + local fd, err = ltask.call("connect", ...) + if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +net.fork = ltask.fork +net.schedule = ltask.schedule +net.yield = ltask.yield + +ltask.dispatch(S) + +return net diff --git a/benchmark/net_epoll.lua b/benchmark/net_epoll.lua new file mode 100644 index 00000000..9d9b0b44 --- /dev/null +++ b/benchmark/net_epoll.lua @@ -0,0 +1,506 @@ +local ltask = require "ltask" +local socket = require "bee.socket" +local epoll = require "bee.epoll" + +local epfd = epoll.create(512) + +local EPOLLIN = epoll.EPOLLIN +local EPOLLOUT = epoll.EPOLLOUT +local EPOLLERR = epoll.EPOLLERR +local EPOLLHUP = epoll.EPOLLHUP + +local kMaxReadBufSize = 64 * 1024 + +local status = {} +local handle = {} + +local function fd_update(s) + local flags = 0 + if s.r then + flags = flags | EPOLLIN + end + if s.w then + flags = flags | EPOLLOUT + end + if flags ~= s.event_flags then + epfd:event_mod(s.fd, flags) + s.event_flags = flags + end +end + +local function fd_set_read(s) + if s.shutdown_r then + return + end + s.r = true + fd_update(s) +end + +local function fd_clr_read(s) + s.r = nil + fd_update(s) +end + +local function fd_set_write(s) + if s.shutdown_w then + return + end + s.w = true + fd_update(s) +end + +local function fd_clr_write(s) + s.w = nil + fd_update(s) +end + +local function fd_init(fd) + local s = status[fd] + local function on_event(e) + if e & (EPOLLERR | EPOLLHUP) ~= 0 then + e = e | EPOLLIN | EPOLLOUT + end + if e & EPOLLIN ~= 0 then + assert(not s.shutdown_r) + s:on_read() + end + if e & EPOLLOUT ~= 0 then + if not s.shutdown_w then + s:on_write() + end + end + end + epfd:event_add(fd, 0, on_event) +end + +local function create_handle(fd) + local h = handle[fd] + if h then + return h + end + h = #handle + 1 + handle[h] = fd + handle[fd] = h + return h +end + +local function 
close(s) + local fd = s.fd + epfd:event_del(fd) + fd:close() + assert(s.shutdown_r) + assert(s.shutdown_w) + if s.wait_read then + assert(#s.wait_read == 0) + end + if s.wait_write then + assert(#s.wait_write == 0) + end + if s.wait_close then + for _, token in ipairs(s.wait_close) do + ltask.wakeup(token) + end + end +end + +local function close_write(s) + if s.shutdown_r and s.shutdown_w then + return + end + if not s.shutdown_w then + s.shutdown_w = true + fd_clr_write(s) + end + if s.shutdown_r then + fd_clr_read(s) + close(s) + end +end + +local function close_read(s) + if s.shutdown_r and s.shutdown_w then + return + end + if not s.shutdown_r then + s.shutdown_r = true + fd_clr_read(s) + if s.wait_read then + for i, token in ipairs(s.wait_read) do + ltask.wakeup(token) + s.wait_read[i] = nil + end + end + end + if s.shutdown_w then + close(s) + elseif not s.wait_write or #s.wait_write == 0 then + s.shutdown_w = true + fd_clr_write(s) + close(s) + end +end + +local function stream_dispatch_read(s) + while #s.wait_read > 0 do + local token = s.wait_read[1] + local n = token[1] + if n == nil then + ltask.wakeup(token, s.readbuf) + s.readbuf = "" + table.remove(s.wait_read, 1) + else + if n > #s.readbuf then + break + end + ltask.wakeup(token, s.readbuf:sub(1, n)) + s.readbuf = s.readbuf:sub(n + 1) + table.remove(s.wait_read, 1) + end + end + if #s.readbuf > kMaxReadBufSize then + fd_clr_read(s) + end +end + +local function stream_on_read(s) + -- 循环读直到 EAGAIN,充分消费内核缓冲区 + local parts + while true do + local data = s.fd:recv() + if data == nil then + -- EOF / 连接关闭 + if parts then + s.readbuf = s.readbuf .. table.concat(parts) + parts = nil + end + stream_dispatch_read(s) + close_read(s) + return + elseif data == false then + -- EAGAIN,本轮读完 + break + else + if parts then + parts[#parts + 1] = data + else + parts = { data } + end + end + end + if parts then + s.readbuf = s.readbuf .. 
table.concat(parts) + end + stream_dispatch_read(s) +end + +local function stream_on_write(s) + -- 首次 EPOLLOUT 意味着 connect 完成 + s.connected = true + while #s.wait_write > 0 do + local data = s.wait_write[1] + local n, err = s.fd:send(data[1]) + if n == nil then + for i, token in ipairs(s.wait_write) do + ltask.interrupt(token, err or "Write close.") + s.wait_write[i] = nil + end + close_write(s) + return + elseif n == false then + return + else + if n == #data[1] then + local token = table.remove(s.wait_write, 1) + ltask.wakeup(token, n) + if #s.wait_write == 0 then + fd_clr_write(s) + return + end + else + data[1] = data[1]:sub(n + 1) + return + end + end + end +end + +local function create_stream(newfd, connected) + local s = { + fd = newfd, + readbuf = "", + wait_read = {}, + wait_write = {}, + shutdown_r = false, + shutdown_w = false, + r = false, + w = false, + event_flags = 0, + connected = connected, + on_read = stream_on_read, + on_write = stream_on_write, + } + status[newfd] = s + newfd:option("nodelay", 1) + fd_init(newfd) + fd_set_read(s) + return create_handle(newfd) +end + +local S = {} + +function S.listen(protocol, ...) + local fd, err = socket.create(protocol) + if not fd then + return nil, err + end + local ok, err = fd:bind(...) 
+    if not ok then
+        -- bind failed: close the socket before reporting, otherwise the fd leaks
+        fd:close()
+        return nil, err
+    end
+    ok, err = fd:listen()
+    if not ok then
+        -- listen failed: release the fd as well
+        fd:close()
+        return nil, err
+    end
+    status[fd] = {
+        fd = fd,
+        shutdown_r = false,
+        shutdown_w = true,  -- a listener never writes
+        r = false,
+        w = false,
+        event_flags = 0,
+    }
+    fd_init(fd)
+    return create_handle(fd)
+end
+
+function S.connect(protocol, host, port)
+    -- If host/port are given, resolve the hostname first, then pick the
+    -- protocol family (tcp/tcp6, udp/udp6) matching the resolved address.
+    local ep
+    if host and port then
+        ep = socket.endpoint("hostname", host, port)
+        if not ep then
+            return nil, string.format("resolve hostname failed: %s:%d", host, port)
+        end
+        local _, _, family = ep:value()
+        if family == "inet6" then
+            if protocol == "tcp" then
+                protocol = "tcp6"
+            elseif protocol == "udp" then
+                protocol = "udp6"
+            end
+        end
+    end
+    local fd, err = socket.create(protocol)
+    if not fd then
+        return nil, err
+    end
+    local r
+    if ep then
+        r, err = fd:connect(ep)
+    else
+        r, err = fd:connect(host, port)
+    end
+    if r == nil then
+        -- hard connect failure: release the fd before reporting
+        fd:close()
+        return nil, err
+    end
+    -- r == true: connected immediately (possible on loopback); r == false: EINPROGRESS
+    return create_stream(fd, r == true)
+end
+
+function S.accept(h)
+    local fd = assert(handle[h], "Invalid fd.")
+    local s = status[fd]
+    s.on_read = ltask.wakeup
+    fd_set_read(s)
+    ltask.wait(s)
+    local newfd, err = fd:accept()
+    if not newfd then
+        return nil, err
+    end
+    local ok, err = newfd:status()
+    if not ok then
+        -- status check failed: close the accepted socket (matches net_async.lua)
+        newfd:close()
+        return nil, err
+    end
+    return create_stream(newfd, true)
+end
+
+function S.send(h, data)
+    local fd = assert(handle[h], "Invalid fd.")
+    local s = status[fd]
+    if not s.wait_write then
+        error "Write not allowed."
+ return + end + if s.shutdown_w then + return + end + if data == "" then + return 0 + end + -- 队列非空或连接尚未就绪时直接排队,保证有序 + if #s.wait_write > 0 or not s.connected then + local token = { data } + s.wait_write[#s.wait_write + 1] = token + fd_set_write(s) + return ltask.wait(token) + end + -- 乐观写:直接尝试发送,避免不必要的 epoll 往返 + local n, err = s.fd:send(data) + if n == nil then + -- 连接出错 + close_write(s) + return nil, err + elseif n == false then + -- EAGAIN:内核缓冲区满,挂起等待 EPOLLOUT + local token = { data } + s.wait_write[#s.wait_write + 1] = token + fd_set_write(s) + return ltask.wait(token) + elseif n < #data then + -- 部分写入:剩余数据入队等 EPOLLOUT + local token = { data:sub(n + 1) } + s.wait_write[#s.wait_write + 1] = token + fd_set_write(s) + return ltask.wait(token) + else + -- 全部写完,直接返回 + return n + end +end + +function S.recv(h, n) + local fd = assert(handle[h], "Invalid fd.") + local s = status[fd] + if not s.readbuf then + error "Read not allowed." + return + end + if s.shutdown_r then + if not n then + if s.readbuf == "" then + return + end + else + if n > #s.readbuf then + return + end + end + end + local sz = #s.readbuf + if not n then + if sz == 0 then + local token = { + } + s.wait_read[#s.wait_read + 1] = token + return ltask.wait(token) + end + local ret = s.readbuf + if sz > kMaxReadBufSize then + fd_set_read(s) + end + s.readbuf = "" + return ret + else + if n <= sz then + local ret = s.readbuf:sub(1, n) + if sz > kMaxReadBufSize and sz - n <= kMaxReadBufSize then + fd_set_read(s) + end + s.readbuf = s.readbuf:sub(n + 1) + return ret + else + local token = { n } + s.wait_read[#s.wait_read + 1] = token + return ltask.wait(token) + end + end +end + +function S.close(h) + local fd = handle[h] + if fd then + local s = status[fd] + close_read(s) + if not s.shutdown_w then + local token = {} + if s.wait_close then + s.wait_close[#s.wait_close + 1] = token + else + s.wait_close = { token } + end + ltask.wait(token) + end + handle[h] = nil + handle[fd] = nil + status[fd] = nil + 
end +end + +function S.is_closed(h) + local fd = handle[h] + if fd then + local s = status[fd] + return s.shutdown_w and s.shutdown_r + end +end + +local fd_mt = {} +fd_mt.__index = fd_mt + +function fd_mt:accept(...) + local fd, err = ltask.call("accept", self.fd, ...) + if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +function fd_mt:send(...) + return ltask.call("send", self.fd, ...) +end + +function fd_mt:recv(...) + return ltask.call("recv", self.fd, ...) +end + +function fd_mt:close(...) + return ltask.call("close", self.fd, ...) +end + +function fd_mt:is_closed(...) + return ltask.call("is_closed", self.fd, ...) +end + +local net = {} + +function net.wait(timeout) + for f, event in epfd:wait(timeout) do + f(event) + end +end + +function net.listen(...) + local fd, err = ltask.call("listen", ...) + if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +function net.connect(...) + local fd, err = ltask.call("connect", ...) 
+ if not fd then + return nil, err + end + return setmetatable({ fd = fd }, fd_mt) +end + +net.fork = ltask.fork +net.schedule = ltask.schedule +net.yield = ltask.yield + +ltask.dispatch(S) + +return net diff --git a/binding/lua_async.cpp b/binding/lua_async.cpp new file mode 100644 index 00000000..334c5d51 --- /dev/null +++ b/binding/lua_async.cpp @@ -0,0 +1,792 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bee::lua_socket { + net::fd_t& newfd(lua_State* L, net::fd_t fd); + net::fd_t& checkfd(lua_State* L, int idx); + net::endpoint& new_endpoint(lua_State* L); + net::endpoint& to_endpoint(lua_State* L, int idx, net::endpoint& ep); +} + +namespace bee::lua_async { + + struct lua_async { + std::unique_ptr handle; + luaref refs = nullptr; + int i = 0; + int n = 0; + dynarray completions; + lua_async(size_t max_completions) + : completions(max_completions) {} + ~lua_async() { + if (refs) luaref_close(refs); + } + }; + + static file_handle::value_type tofilefd(lua_State* L, int idx) { + luaL_Stream* p = lua::tofile(L, idx); + return file_handle::from_file(p->f).value(); + } + + // ---- request_id packing ---- + // high 32 bits = buf_ref, low 32 bits = udata_ref + + static uint64_t make_request_id(int buf_ref, int udata_ref) { + return (static_cast(static_cast(buf_ref)) << 32) + | static_cast(static_cast(udata_ref)); + } + + static int get_buf_ref(uint64_t id) { + return static_cast(static_cast(id >> 32)); + } + + static int get_udata_ref(uint64_t id) { + return static_cast(static_cast(id & 0xFFFFFFFF)); + } + + static void unref_buf(lua_async& as, int ref) { + if (ref) luaref_unref(as.refs, ref); + } + + // ---- read_buf ---- + + struct read_buf { + void* ud; + lua_Alloc allocf; + char* buf; + size_t len; + + static read_buf* create(lua_State* L, size_t len) { + lua_Alloc allocf; + void* ud = nullptr; + allocf = lua_getallocf(L, &ud); + read_buf* self 
= static_cast(allocf(ud, NULL, 0, sizeof(read_buf))); + self->ud = ud; + self->allocf = allocf; + self->len = len; + self->buf = static_cast(allocf(ud, NULL, 0, len + 1)); + self->buf[len] = '\0'; + return self; + } + void push_string(lua_State* L, size_t bytes) { +#if LUA_VERSION_NUM >= 505 + buf[bytes] = '\0'; + lua_pushexternalstring(L, buf, bytes, allocf, ud); +#else + lua_pushlstring(L, buf, bytes); + allocf(ud, buf, len + 1, 0); +#endif + buf = nullptr; + allocf(ud, this, sizeof(read_buf), 0); + } + void destroy() { + if (buf) allocf(ud, buf, len + 1, 0); + allocf(ud, this, sizeof(read_buf), 0); + } + }; + + static char* alloc_read_buf(lua_State* L, lua_async& as, int udata_idx, lua_Integer len, uint64_t& out_id) { + read_buf* rb = read_buf::create(L, static_cast(len)); + lua_pushlightuserdata(L, rb); + int buf_r = luaref_ref(as.refs, L); + lua_pushvalue(L, udata_idx); + int udata_r = luaref_ref(as.refs, L); + out_id = make_request_id(buf_r, udata_r); + return rb->buf; + } + + // ---- write_buf drain helpers ---- + + // Submit all queued entries as a single writev. wb.fd and wb.lua_reqid must already be set. + // Builds wb.iov_cache from the current queue (honouring per-entry offsets) and submits once. + static bool wb_submit_all(lua_async& as, async::write_buf& wb) { + if (wb.q.empty()) return true; + size_t n = wb.q.size(); + wb.iov_cache = dynarray(n); + size_t i = 0; + for (auto& e : wb.q) { + wb.iov_cache[i++].set(e.data + e.offset, e.len - e.offset); + } + if (!as.handle->submit_writev(wb.fd, span(wb.iov_cache.data(), n), wb.lua_reqid)) { + return false; + } + wb.in_flight = true; + return true; + } + + // Handle writev completion for a write_buf-managed packed id. + // Returns true → emit the Lua-visible completion. + // Returns false → still draining, no Lua completion yet. 
+ static bool wb_on_completion(lua_State* L, lua_async& as, uint64_t packed_id, async::write_buf& wb, async::async_status status, size_t bytes) { + wb.in_flight = false; + int buf_r = get_buf_ref(packed_id); + + if (status != async::async_status::success) { + for (auto& e : wb.q) luaL_unref(L, LUA_REGISTRYINDEX, e.str_ref); + wb.q.clear(); + wb.buffered = 0; + unref_buf(as, buf_r); + return true; + } + + if (wb.q.empty()) { + unref_buf(as, buf_r); + return true; + } + + // Consume bytes_transferred from the front of the queue, honouring + // partial writes across multiple entries. + size_t remaining = bytes; + while (!wb.q.empty() && remaining > 0) { + async::write_buf::entry& front = wb.q.front(); + size_t avail = front.len - front.offset; + if (remaining >= avail) { + remaining -= avail; + luaL_unref(L, LUA_REGISTRYINDEX, front.str_ref); + wb.buffered -= front.len; + wb.q.pop_front(); + } else { + front.offset += remaining; + remaining = 0; + } + } + + if (!wb.q.empty()) { + // Still data to send (partial write): resubmit all remaining entries. + if (!wb_submit_all(as, wb)) { + // resubmit 失败:释放队列中所有字符串引用,清空队列,避免资源泄漏 + for (auto& e : wb.q) luaL_unref(L, LUA_REGISTRYINDEX, e.str_ref); + wb.q.clear(); + wb.buffered = 0; + unref_buf(as, buf_r); + return true; + } + return false; + } + + unref_buf(as, buf_r); + return true; + } + + // udata_r is consumed (unref'd) by this call. 
+ static void push_udata(lua_State* L, lua_async& as, int udata_r) { + if (udata_r) { + luaref_get(as.refs, L, udata_r); + luaref_unref(as.refs, udata_r); + } else { + lua_pushnil(L); + } + } + + // ---- completion iterator ---- + + static int async_completions(lua_State* L) { + auto& as = *(lua_async*)lua_touserdata(L, lua_upvalueindex(1)); + again: + if (as.i >= as.n) return 0; + + const auto& c = as.completions[as.i]; + as.i++; + int buf_r = get_buf_ref(c.request_id); + int udata_r = get_udata_ref(c.request_id); + + if (c.op == async::async_op::writev) { + async::write_buf* wb = nullptr; + if (buf_r) { + luaref_get(as.refs, L, buf_r); + if (lua_type(L, -1) == LUA_TUSERDATA) { + void* p = luaL_testudata(L, -1, reflection::name_v.data()); + if (p) wb = lua::udata_align(p); + } + lua_pop(L, 1); + } + if (wb) { + async::async_status st = c.status; + size_t bytes = c.bytes_transferred; + bool done = wb_on_completion(L, as, c.request_id, *wb, st, bytes); + if (!done) goto again; + lua_pushinteger(L, static_cast(std::to_underlying(c.op))); + push_udata(L, as, udata_r); + lua_pushinteger(L, static_cast(std::to_underlying(st))); + lua_pushinteger(L, 0); + lua_pushinteger(L, static_cast(c.error_code)); + return 5; + } + } + + if (c.op == async::async_op::readv) { + // readv completion: commit bytes to the ring_buf and report as OP_READV. 
+ async::ring_buf* rb = nullptr; + if (buf_r) { + luaref_get(as.refs, L, buf_r); + if (lua_type(L, -1) == LUA_TUSERDATA) { + void* p = luaL_testudata(L, -1, reflection::name_v.data()); + if (p) rb = lua::udata_align(p); + } + lua_pop(L, 1); + luaref_unref(as.refs, buf_r); + } + if (rb && c.status == async::async_status::success) rb->commit(c.bytes_transferred); + lua_pushinteger(L, static_cast(std::to_underlying(c.op))); + push_udata(L, as, udata_r); + lua_pushinteger(L, static_cast(std::to_underlying(c.status))); + lua_pushinteger(L, static_cast(c.bytes_transferred)); + lua_pushinteger(L, static_cast(c.error_code)); + return 5; + } + + lua_pushinteger(L, static_cast(std::to_underlying(c.op))); + push_udata(L, as, udata_r); + lua_pushinteger(L, static_cast(std::to_underlying(c.status))); + + switch (c.op) { + case async::async_op::accept: + if (c.status == async::async_status::success) { + lua_socket::newfd(L, static_cast(c.bytes_transferred)); + } else { + lua_pushinteger(L, 0); + } + break; + case async::async_op::connect: + unref_buf(as, buf_r); + lua_pushinteger(L, static_cast(c.bytes_transferred)); + break; + case async::async_op::read: { + async::ring_buf* rb = nullptr; + if (buf_r) { + luaref_get(as.refs, L, buf_r); + if (lua_type(L, -1) == LUA_TUSERDATA) { + void* p = luaL_testudata(L, -1, reflection::name_v.data()); + if (p) rb = lua::udata_align(p); + } + lua_pop(L, 1); + luaref_unref(as.refs, buf_r); + } + if (rb && c.status == async::async_status::success) rb->commit(c.bytes_transferred); + lua_pushinteger(L, static_cast(c.bytes_transferred)); + break; + } + case async::async_op::file_read: { + read_buf* rb = nullptr; + if (buf_r) { + luaref_get(as.refs, L, buf_r); + rb = static_cast(lua_touserdata(L, -1)); + lua_pop(L, 1); + luaref_unref(as.refs, buf_r); + } + if (c.status == async::async_status::success && rb) { + rb->push_string(L, c.bytes_transferred); + } else { + if (rb) rb->destroy(); + lua_pushinteger(L, static_cast(c.bytes_transferred)); + } + 
break; + } + default: + unref_buf(as, buf_r); + lua_pushinteger(L, static_cast(c.bytes_transferred)); + break; + } + lua_pushinteger(L, static_cast(c.error_code)); + return 5; + } + + // ---- fd helpers ---- + + // Accept both socket userdata (from bee.socket) and light userdata (from channel:fd()). + static net::fd_t checkfd_any(lua_State* L, int idx) { + if (lua_type(L, idx) == LUA_TLIGHTUSERDATA) { + return static_cast(reinterpret_cast(lua_touserdata(L, idx))); + } + return lua_socket::checkfd(L, idx); + } + + // Make a packed request_id and pin both values. + static uint64_t pin(lua_State* L, lua_async& as, int buf_idx, int udata_idx) { + lua_pushvalue(L, buf_idx); + int buf_r = luaref_ref(as.refs, L); + lua_pushvalue(L, udata_idx); + int udata_r = luaref_ref(as.refs, L); + return make_request_id(buf_r, udata_r); + } + + // Make a packed request_id with no buffer (buf_ref=0). + static uint64_t pin_udata(lua_State* L, lua_async& as, int udata_idx) { + lua_pushvalue(L, udata_idx); + int udata_r = luaref_ref(as.refs, L); + return make_request_id(0, udata_r); + } + + // Unref both refs from a packed id (used on submit failure). + static void pin_release(lua_async& as, uint64_t id) { + unref_buf(as, get_buf_ref(id)); + unref_buf(as, get_udata_ref(id)); + } + + // submit_write(asfd, wb, fd, udata) + // Drains wb's queue. C layer submits all entries at once via writev, then emits one Lua completion. 
+ static int async_submit_write(lua_State* L) { + auto& as = lua::checkudata(L, 1); + auto& wb = lua::checkudata(L, 2); + net::fd_t fd = lua_socket::checkfd(L, 3); + luaL_checkany(L, 4); + + if (wb.in_flight || wb.q.empty()) { + lua_pushboolean(L, 1); + return 1; + } + + uint64_t id = pin(L, as, 2, 4); + wb.fd = fd; + wb.lua_reqid = id; + + if (!wb_submit_all(as, wb)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_write"); + } + lua_pushboolean(L, 1); + return 1; + } + + // submit_read(asfd, rb, fd, udata) + // Handles ring_buf wrap-around automatically: when the free region wraps around + // the end of the buffer, submits two iovecs covering both segments at once. + // Falls back to a single-buffer read when there is no wrap-around. + static int async_submit_read(lua_State* L) { + auto& as = lua::checkudata(L, 1); + auto& rb = lua::checkudata(L, 2); + net::fd_t fd = lua_socket::checkfd(L, 3); + luaL_checkany(L, 4); + + size_t len1 = rb.write_len(); + if (len1 == 0) { + lua_pushboolean(L, 0); + return 1; + } + + size_t free_space = rb.free_cap(); + size_t len2 = (len1 < free_space) ? (free_space - len1) : 0; + + if (len2 == 0) { + // No wrap-around: fall back to a single-buffer submit_read. + uint64_t id = pin(L, as, 2, 4); + if (!as.handle->submit_read(fd, rb.write_ptr(), len1, id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_read"); + } + lua_pushboolean(L, 1); + return 1; + } + + // Two segments: first runs to the buffer end, second wraps to the beginning. 
+ net::socket::iobuf bufs[2]; + bufs[0].set(rb.write_ptr(), len1); + bufs[1].set(rb.data, len2); + + uint64_t id = pin(L, as, 2, 4); + if (!as.handle->submit_readv(fd, span(bufs, 2), id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_read"); + } + lua_pushboolean(L, 1); + return 1; + } + + static int async_submit_accept(lua_State* L) { + auto& as = lua::checkudata(L, 1); + net::fd_t fd = lua_socket::checkfd(L, 2); + luaL_checkany(L, 3); + uint64_t id = pin_udata(L, as, 3); + if (!as.handle->submit_accept(fd, id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_accept"); + } + lua_pushboolean(L, 1); + return 1; + } + + static int async_submit_connect(lua_State* L) { + auto& as = lua::checkudata(L, 1); + net::fd_t fd = lua_socket::checkfd(L, 2); + net::endpoint stack_ep; + const net::endpoint* ep_ptr; + int udata_idx; + if (lua_type(L, 3) == LUA_TUSERDATA) { + ep_ptr = &lua_socket::to_endpoint(L, 3, stack_ep); + udata_idx = 4; + lua_pushvalue(L, 3); + } else { + auto name = lua::checkstrview(L, 3); + auto port = lua::checkinteger(L, 4); + udata_idx = 5; + auto& ep = lua_socket::new_endpoint(L); + if (!net::endpoint::ctor_hostname(ep, name, port)) + return lua::return_error(L, "invalid endpoint"); + ep_ptr = &ep; + } + luaL_checkany(L, udata_idx); + uint64_t id = pin(L, as, lua_gettop(L), udata_idx); + lua_pop(L, 1); + if (!as.handle->submit_connect(fd, *ep_ptr, id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_connect"); + } + lua_pushboolean(L, 1); + return 1; + } + + static int async_submit_file_read(lua_State* L) { + auto& as = lua::checkudata(L, 1); + file_handle::value_type fd = tofilefd(L, 2); + lua_Integer len = luaL_checkinteger(L, 3); + lua_Integer offset = luaL_optinteger(L, 4, 0); + luaL_checkany(L, 5); + if (len <= 0) return luaL_error(L, "buffer size must be positive"); + uint64_t id = 0; + void* buffer = alloc_read_buf(L, as, 5, len, id); + if (!as.handle->submit_file_read(fd, buffer, 
static_cast(len), static_cast(offset), id)) { + int buf_r = get_buf_ref(id); + if (buf_r) { + luaref_get(as.refs, L, buf_r); + read_buf* rb = static_cast(lua_touserdata(L, -1)); + lua_pop(L, 1); + if (rb) rb->destroy(); + } + pin_release(as, id); + return lua::return_net_error(L, "submit_file_read"); + } + lua_pushboolean(L, 1); + return 1; + } + + static int async_submit_file_write(lua_State* L) { + auto& as = lua::checkudata(L, 1); + file_handle::value_type fd = tofilefd(L, 2); + size_t len = 0; + const char* data = luaL_checklstring(L, 3, &len); + lua_Integer offset = luaL_optinteger(L, 4, 0); + luaL_checkany(L, 5); + uint64_t id = pin(L, as, 3, 5); + if (!as.handle->submit_file_write(fd, data, len, static_cast(offset), id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_file_write"); + } + lua_pushboolean(L, 1); + return 1; + } + + // submit_poll(asfd, fd, udata) + static int async_submit_poll(lua_State* L) { + auto& as = lua::checkudata(L, 1); + net::fd_t fd = checkfd_any(L, 2); + luaL_checkany(L, 3); + uint64_t id = pin_udata(L, as, 3); + if (!as.handle->submit_poll(fd, id)) { + pin_release(as, id); + return lua::return_net_error(L, "submit_poll"); + } + lua_pushboolean(L, 1); + return 1; + } + + static int async_poll(lua_State* L) { + auto& as = lua::checkudata(L, 1); + as.i = 0; + as.n = as.handle->poll(span(as.completions.data(), as.completions.size())); + lua_getiuservalue(L, 1, 1); + return 1; + } + + static int async_wait(lua_State* L) { + auto& as = lua::checkudata(L, 1); + int timeout = lua::optinteger(L, 2); + as.i = 0; + as.n = as.handle->wait(span(as.completions.data(), as.completions.size()), timeout); + lua_getiuservalue(L, 1, 1); + return 1; + } + + static int async_associate(lua_State* L) { +#if defined(_WIN32) + auto& as = lua::checkudata(L, 1); + net::fd_t fd = checkfd_any(L, 2); + lua_pushboolean(L, as.handle->associate(fd) ? 
1 : 0); +#else + lua_pushboolean(L, 1); +#endif + return 1; + } + + static int async_associate_file(lua_State* L) { +#if defined(_WIN32) + auto& as = lua::checkudata(L, 1); + file_handle::value_type fd = tofilefd(L, 2); + auto [ov_fh, writable] = as.handle->associate_file(fd); + if (!ov_fh) + return lua::return_net_error(L, "associate_file"); + auto fm = writable ? file_handle::mode::write : file_handle::mode::read; + FILE* fp = ov_fh.to_file(fm); + if (!fp) { + ov_fh.close(); + return lua::return_net_error(L, "associate_file"); + } + luaL_Stream* stream = lua::tofile(L, 2); + fclose(stream->f); + stream->f = fp; + lua_pushboolean(L, 1); +#else + lua_pushboolean(L, 1); +#endif + return 1; + } + + static int async_cancel(lua_State* L) { + auto& as = lua::checkudata(L, 1); + net::fd_t fd = lua_socket::checkfd(L, 2); + as.handle->cancel(fd); + return 0; + } + + static int async_stop(lua_State* L) { + auto& as = lua::checkudata(L, 1); + as.handle->stop(); + lua_pushboolean(L, 1); + return 1; + } + + static int async_mt_close(lua_State* L) { + auto& as = lua::checkudata(L, 1); + as.handle->stop(); + return 0; + } + + // ---- ring_buf (readbuf) methods ---- + + static int rb_read(lua_State* L) { + auto& rb = lua::checkudata(L, 1); + if (lua_isnoneornil(L, 2)) { + size_t n = rb.size(); + if (n == 0) { + lua_pushnil(L); + return 1; + } + luaL_Buffer b; + char* dst = luaL_buffinitsize(L, &b, n); + rb.consume(dst, n); + luaL_pushresultsize(&b, n); + return 1; + } + lua_Integer n = luaL_checkinteger(L, 2); + if (n <= 0) return luaL_error(L, "n must be positive"); + size_t ulen = static_cast(n); + if (rb.size() < ulen) { + lua_pushnil(L); + return 1; + } + luaL_Buffer b; + char* dst = luaL_buffinitsize(L, &b, ulen); + rb.consume(dst, ulen); + luaL_pushresultsize(&b, ulen); + return 1; + } + + static int rb_readline(lua_State* L) { + auto& rb = lua::checkudata(L, 1); + size_t seplen = 0; + const char* sep = luaL_optlstring(L, 2, "\r\n", &seplen); + if (seplen == 0) return 
luaL_error(L, "separator must not be empty"); + size_t n = rb.find(sep, seplen); + if (n == 0) { + lua_pushnil(L); + return 1; + } + luaL_Buffer b; + char* dst = luaL_buffinitsize(L, &b, n); + rb.consume(dst, n); + luaL_pushresultsize(&b, n); + return 1; + } + + static int async_readbuf_create(lua_State* L) { + lua_Integer bufsize = luaL_checkinteger(L, 1); + if (bufsize <= 0) return luaL_error(L, "bufsize must be positive"); + lua::newudata(L, static_cast(bufsize)); + return 1; + } + + // ---- write_buf (writebuf) methods ---- + + // async.writebuf(hwm) -> write_buf userdata + static int async_writebuf_create(lua_State* L) { + lua_Integer hwm = luaL_optinteger(L, 1, 64 * 1024); + if (hwm <= 0) return luaL_error(L, "hwm must be positive"); + auto& wb = lua::newudata(L); + wb.hwm = static_cast(hwm); + return 1; + } + + // wb:write(data) -> bool (true = buffered >= hwm after enqueue) + static int wb_write(lua_State* L) { + auto& wb = lua::checkudata(L, 1); + size_t len = 0; + const char* data = luaL_checklstring(L, 2, &len); + if (len == 0) { + lua_pushboolean(L, 0); + return 1; + } + lua_pushvalue(L, 2); + int str_ref = luaL_ref(L, LUA_REGISTRYINDEX); + async::write_buf::entry e; + e.data = data; + e.len = len; + e.offset = 0; + e.str_ref = str_ref; + wb.q.push_back(e); + wb.buffered += len; + lua_pushboolean(L, wb.buffered >= wb.hwm ? 
1 : 0); + return 1; + } + + // wb:buffered() -> integer + static int wb_buffered(lua_State* L) { + auto& wb = lua::checkudata(L, 1); + lua_pushinteger(L, static_cast(wb.buffered)); + return 1; + } + + // wb:close() -- release all queued strings (called on stream close) + static int wb_close(lua_State* L) { + auto& wb = lua::checkudata(L, 1); + for (auto& e : wb.q) luaL_unref(L, LUA_REGISTRYINDEX, e.str_ref); + wb.q.clear(); + wb.buffered = 0; + return 0; + } + + // ---- metatable / module ---- + + static void metatable(lua_State* L) { + static luaL_Reg lib[] = { + { "submit_write", async_submit_write }, + { "submit_read", async_submit_read }, + { "submit_accept", async_submit_accept }, + { "submit_connect", async_submit_connect }, + { "submit_file_read", async_submit_file_read }, + { "submit_file_write", async_submit_file_write }, + { "submit_poll", async_submit_poll }, + { "associate", async_associate }, + { "associate_file", async_associate_file }, + { "cancel", async_cancel }, + { "poll", async_poll }, + { "wait", async_wait }, + { "stop", async_stop }, + { NULL, NULL } + }; + luaL_newlibtable(L, lib); + luaL_setfuncs(L, lib, 0); + lua_setfield(L, -2, "__index"); + static luaL_Reg mt[] = { + { "__close", async_mt_close }, + { NULL, NULL } + }; + luaL_setfuncs(L, mt, 0); + } + + static int async_create(lua_State* L) { + lua_Integer max_completions = luaL_optinteger(L, 1, 64); + if (max_completions <= 0) + return lua::return_error(L, "max_completions is less than or equal to zero."); + auto handle = async::create(); + if (!handle) + return lua::return_error(L, "failed to create async backend"); + lua::newudata(L, static_cast(max_completions)); + auto& as = lua::checkudata(L, -1); + as.handle = std::move(handle); + as.refs = luaref_init(L); + lua_pushvalue(L, -1); + lua_pushcclosure(L, async_completions, 1); + lua_setiuservalue(L, -2, 1); + return 1; + } + + static int luaopen(lua_State* L) { + struct luaL_Reg l[] = { + { "create", async_create }, + { "readbuf", 
async_readbuf_create }, + { "writebuf", async_writebuf_create }, + { NULL, NULL }, + }; + luaL_newlib(L, l); + +#define SETENUM(E, V) \ + lua_pushinteger(L, static_cast(std::to_underlying(V))); \ + lua_setfield(L, -2, #E) + + SETENUM(SUCCESS, async::async_status::success); + SETENUM(CLOSE, async::async_status::close); + SETENUM(ERROR, async::async_status::error); + SETENUM(CANCEL, async::async_status::cancel); + + SETENUM(OP_READ, async::async_op::read); + SETENUM(OP_READV, async::async_op::readv); + SETENUM(OP_WRITE, async::async_op::write); + SETENUM(OP_WRITEV, async::async_op::writev); + SETENUM(OP_ACCEPT, async::async_op::accept); + SETENUM(OP_CONNECT, async::async_op::connect); + SETENUM(OP_FILE_READ, async::async_op::file_read); + SETENUM(OP_FILE_WRITE, async::async_op::file_write); + SETENUM(OP_POLL, async::async_op::fd_poll); +#undef SETENUM + return 1; + } +} + +DEFINE_LUAOPEN(async) + +namespace bee::lua { + template <> + struct udata { + static inline int nupvalue = 1; + static inline auto metatable = bee::lua_async::metatable; + }; + template <> + struct udata { + static inline auto metatable = [](lua_State* L) { + static luaL_Reg lib[] = { + { "read", lua_async::rb_read }, + { "readline", lua_async::rb_readline }, + { NULL, NULL } + }; + luaL_newlibtable(L, lib); + luaL_setfuncs(L, lib, 0); + lua_setfield(L, -2, "__index"); + }; + }; + template <> + struct udata { + static inline auto metatable = [](lua_State* L) { + static luaL_Reg lib[] = { + { "write", lua_async::wb_write }, + { "buffered", lua_async::wb_buffered }, + { "close", lua_async::wb_close }, + { NULL, NULL } + }; + luaL_newlibtable(L, lib); + luaL_setfuncs(L, lib, 0); + lua_setfield(L, -2, "__index"); + }; + }; +} diff --git a/binding/lua_socket.cpp b/binding/lua_socket.cpp index f857f2cc..f961af14 100644 --- a/binding/lua_socket.cpp +++ b/binding/lua_socket.cpp @@ -528,15 +528,15 @@ namespace bee::lua_socket { return 1; } case endpoint_ctor::hostname: { - auto name = lua::checkstrview(L, 
2); - auto port = lua::checkinteger(L, 3); + auto name = lua::checkstrview(L, 2); + auto port = lua::checkinteger(L, 3); auto af_hint = net::family::unknown; if (!lua_isnoneornil(L, 4)) { - static const char* const af_opts[] = { "inet", "inet6", NULL }; + static const char* const af_opts[] = { "inet", "inet6", NULL }; static const net::family af_values[] = { net::family::inet, net::family::inet6 }; - af_hint = af_values[luaL_checkoption(L, 4, NULL, af_opts)]; + af_hint = af_values[luaL_checkoption(L, 4, NULL, af_opts)]; } - auto& ep = lua::newudata(L); + auto& ep = lua::newudata(L); if (!net::endpoint::ctor_hostname(ep, name, port, af_hint)) { return 0; } @@ -608,6 +608,22 @@ namespace bee::lua_socket { luaL_setfuncs(L, lib, 0); return 1; } + + net::fd_t& newfd(lua_State* L, net::fd_t fd) { + return lua::newudata(L, fd); + } + + net::fd_t& checkfd(lua_State* L, int idx) { + return lua::checkudata(L, idx); + } + + net::endpoint& new_endpoint(lua_State* L) { + return lua::newudata(L); + } + + net::endpoint& to_endpoint(lua_State* L, int idx, net::endpoint& ep) { + return fd::to_endpoint(L, idx, ep); + } } DEFINE_LUAOPEN(socket) diff --git a/compile/common.lua b/compile/common.lua index 9e54a64c..6b013dea 100644 --- a/compile/common.lua +++ b/compile/common.lua @@ -175,8 +175,13 @@ lm:source_set "source_bee" { need { "osx", "posix", - } - } + }, + lm.async_backend == "kqueue" and { + "!bee/async/async_osx.cpp", + "bee/async/async_bsd.cpp" + }, + }, + defines = lm.async_backend == "kqueue" and "BEE_ASYNC_BACKEND_KQUEUE", }, ios = { sources = { @@ -193,8 +198,12 @@ lm:source_set "source_bee" { need { "linux", "posix", - } - } + }, + lm.async_backend == "epoll" and { + "!bee/async/async_uring_linux.cpp", + }, + }, + defines = lm.async_backend == "epoll" and "BEE_ASYNC_BACKEND_EPOLL", }, android = { sources = need { diff --git a/meta/async.lua b/meta/async.lua new file mode 100644 index 00000000..5aa8240e --- /dev/null +++ b/meta/async.lua @@ -0,0 +1,196 @@ +---@meta 
bee.async + +---异步I/O库 +---跨平台的异步I/O API,在macOS上使用GCD实现,在Windows上使用IOCP实现,在Linux上使用io_uring/epoll实现 +---@class bee.async +---@field SUCCESS integer 操作成功 +---@field CLOSE integer 连接关闭 +---@field ERROR integer 操作错误 +---@field CANCEL integer 操作取消 +---@field OP_READ integer 流式读操作 +---@field OP_READV integer 流式读操作(ring buffer 回绕时双段) +---@field OP_WRITE integer 单次写操作 +---@field OP_WRITEV integer writebuf 写操作 +---@field OP_ACCEPT integer accept 操作 +---@field OP_CONNECT integer connect 操作 +---@field OP_FILE_READ integer 文件读操作 +---@field OP_FILE_WRITE integer 文件写操作 +---@field OP_POLL integer poll 操作 +local async = {} + +---异步I/O实例对象 +---@class bee.async.fd +local asfd = {} + +---提交流式异步读操作(使用 ring buffer,自动处理回绕) +---若 ring buffer 空闲不足(背压)则不投递,返回 false。 +---当空闲区域跨越缓冲区末尾时(回绕场景),自动拆分为两段一次提交, +---减少系统调用次数;无回绕时等同于单段读取。 +---回绕时 completion 的 op 为 OP_READV,无回绕时为 OP_READ。 +---底层 submit 系统调用失败时返回 nil, err。 +---@param rb bee.async.readbuf 接收缓冲区对象 +---@param fd bee.socket.fd socket 对象 +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功投递返回true,背压返回false,系统调用失败返回nil +---@return string? # 系统调用失败时的错误消息 +function asfd:submit_read(rb, fd, udata) +end + +---提交 writebuf 异步写操作 +---将 wb 队列中的所有数据写入 fd。C 层自动 drain(包括 partial write 重试), +---队列清空后产生一次 Lua-visible completion(bytes=0)。 +---若 wb 为空或已有 in-flight 请求,则直接返回 true 不投递。 +---@param wb bee.async.writebuf 写缓冲区对象 +---@param fd bee.socket.fd socket 对象 +---@param udata any 用户自定义数据,队列清空时作为 completion 的 udata 返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? # 错误消息 +function asfd:submit_write(wb, fd, udata) +end + +---提交异步accept操作 +---@param listen_fd bee.socket.fd 监听 socket 对象 +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? 
# 错误消息 +function asfd:submit_accept(listen_fd, udata) +end + +---提交异步connect操作 +---@param fd bee.socket.fd socket 对象 +---@param host string 目标主机名或IP地址 +---@param port integer 目标端口号 +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? # 错误消息 +---@overload fun(self: bee.async.fd, fd: bee.socket.fd, ep: bee.socket.endpoint, udata: any): boolean?, string? +function asfd:submit_connect(fd, host, port, udata) +end + +---将文件关联到当前异步I/O实例(仅 Windows/IOCP) +---Windows 下会就地替换传入文件对象的底层句柄为 overlapped/IOCP 关联句柄 +---必须在首次提交文件 I/O 操作之前调用;非 Windows 平台为 no-op 并直接返回 true +---@param fd file* 原始文件对象(通过 io.open 获取;调用后原对象被就地更新) +---@return boolean? # 成功返回true,失败返回nil +---@return string? # 错误消息 +function asfd:associate_file(fd) +end + +---提交异步文件读操作 +---@param fd file* 文件对象 +---@param len integer 读取长度 +---@param offset? integer 文件偏移量,默认为0 +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? # 错误消息 +function asfd:submit_file_read(fd, len, offset, udata) +end + +---提交异步文件写操作 +---@param fd file* 文件对象 +---@param data string 要写入的数据 +---@param offset? integer 文件偏移量,默认为0 +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? # 错误消息 +function asfd:submit_file_write(fd, data, offset, udata) +end + +---提交异步 poll 操作 +---仅监听 fd 的可读性,不消费任何数据。当 fd 变为可读时产生一次 completion。 +---适用于监听 channel 的 ev fd,收到通知后由调用方自行调用 channel:pop() 消费数据。 +---@param fd bee.socket.fd socket 对象(或 channel:fd() 返回的 fd) +---@param udata any 用户自定义数据,completion 时原样返回 +---@return boolean? # 成功返回true,失败返回nil +---@return string? 
# 错误消息 +function asfd:submit_poll(fd, udata) +end + +---将 socket 关联到当前异步I/O实例(仅 Windows/IOCP) +---必须在首次提交任何 I/O 操作之前调用 +---@param fd bee.socket.fd socket 对象 +---@return boolean # 成功返回true,失败返回false +function asfd:associate(fd) +end + +---取消指定 socket 上的所有待处理 I/O 操作(仅 Windows/IOCP) +---通常在关闭 socket 前调用,以确保所有 overlapped 操作及时完成 +---@param fd bee.socket.fd socket 对象 +function asfd:cancel(fd) +end + +---轮询已完成的I/O事件(非阻塞) +---accept 操作完成时第四个返回值为新的 socket userdata,file_read 完成时为读取到的字符串数据,其他操作为 bytes_transferred +---@return fun(): integer, any, integer, integer|bee.socket.fd|string, integer # 迭代器,产生 (op, udata, status, bytes_transferred|accepted_socket|read_data, error_code) +function asfd:poll() +end + +---等待已完成的I/O事件(阻塞) +---accept 操作完成时第四个返回值为新的 socket userdata,file_read 完成时为读取到的字符串数据,其他操作为 bytes_transferred +---@param timeout? integer 超时时间,单位为毫秒,-1表示无限等待 +---@return fun(): integer, any, integer, integer|bee.socket.fd|string, integer # 迭代器,产生 (op, udata, status, bytes_transferred|accepted_socket|read_data, error_code) +function asfd:wait(timeout) +end + +---停止异步实例 +---@return boolean +function asfd:stop() +end + +---创建写缓冲区对象 +---@param hwm? integer 高水位阈值(字节数),默认 65536 +---@return bee.async.writebuf +function async.writebuf(hwm) +end + +---写缓冲区对象 +---由 async.writebuf() 创建,通过 wb:write / asfd:submit_write 使用 +---@class bee.async.writebuf +local writebuf = {} + +---将数据追加到写缓冲区 +---@param data string 要发送的数据 +---@return boolean # true 表示缓冲字节数 >= hwm(调用方应在 Lua 侧背压等待) +function writebuf:write(data) +end + +---返回当前缓冲队列中的总字节数 +---@return integer +function writebuf:buffered() +end + +---释放队列中所有待发字符串(在流关闭时调用) +function writebuf:close() +end + + +---@param bufsize integer 缓冲区大小(会向上取整到最近的2的幂) +---@return bee.async.readbuf +function async.readbuf(bufsize) +end + +---接收缓冲区对象(ring buffer) +---由 async.readbuf() 创建,通过 submit_read / rb:read() 使用 +---@class bee.async.readbuf +local readbuf = {} + +---从 ring buffer 读取数据 +---@param n? integer 读取字节数,nil 表示读取全部可用数据 +---@return string? 
# 成功返回数据字符串,数据不足返回 nil +function readbuf:read(n) +end + +---从 ring buffer 读取一行(含末尾分隔符) +---@param sep? string 行分隔符,默认为 "\r\n" +---@return string? # 找到分隔符则返回该行(含分隔符),否则返回 nil +function readbuf:readline(sep) +end + +---创建异步I/O实例 +---@param max_completions? integer 最大完成事件数量,默认为64 +---@return bee.async.fd? # 异步I/O实例 +---@return string? # 错误消息 +function async.create(max_completions) +end + +return async diff --git a/test.lua b/test.lua index 1edab56b..402074c0 100644 --- a/test.lua +++ b/test.lua @@ -1,15 +1,40 @@ local subprocess = require "bee.subprocess" local platform = require "bee.platform" +local fs = require "bee.filesystem" -local luaexe = platform.os == "windows" +local luaexe = fs.absolute(platform.os == "windows" and "./build/bin/bootstrap.exe" - or "./build/bin/bootstrap" + or "./build/bin/bootstrap"):string() -local process = assert(subprocess.spawn { - luaexe, "test/test.lua", arg, - stdout = io.stdout, - stderr = "stdout", -}) +local bench = false +local bench_args = {} +local test_args = {} +for i, v in ipairs(arg) do + if v == "-bench" then + bench = true + elseif bench then + bench_args[#bench_args+1] = v + else + test_args[#test_args+1] = v + end +end + +local process +if bench then + process = assert(subprocess.spawn { + luaexe, + "bench.lua", bench_args, + stdout = io.stdout, + stderr = "stdout", + cwd = "benchmark" + }) +else + process = assert(subprocess.spawn { + luaexe, "test/test.lua", test_args, + stdout = io.stdout, + stderr = "stdout", + }) +end local code = process:wait() if code ~= 0 then diff --git a/test/test.lua b/test/test.lua index b5c100c8..24456811 100644 --- a/test/test.lua +++ b/test/test.lua @@ -39,6 +39,7 @@ require "test_thread" require "test_subprocess" require "test_socket" require "test_epoll" +require "test_async" require "test_filewatch" require "test_time" require "test_channel" diff --git a/test/test_async.lua b/test/test_async.lua new file mode 100644 index 00000000..e1e9d48e --- /dev/null +++ b/test/test_async.lua @@ 
-0,0 +1,669 @@ +local lt = require "ltest" +local async = require "bee.async" +local socket = require "bee.socket" +local time = require "bee.time" +local select = require "bee.select" +local platform = require "bee.platform" + +local m = lt.test "async" + +local SUCCESS = async.SUCCESS +local CLOSE = async.CLOSE +local ERROR = async.ERROR +local CANCEL = async.CANCEL + +local function SimpleServer(as, protocol, ...) + local fd = assert(socket.create(protocol)) + assert(as:associate(fd)) + assert(fd:bind(...)) + assert(fd:listen()) + return fd +end + +local function SimpleClient(as, protocol, ...) + local fd = assert(socket.create(protocol)) + assert(as:associate(fd)) + local ok, err = fd:connect(...) + assert(ok ~= nil, err) + return fd +end + +local function wait_accept(as, sfd) + local s = select.create() + s:event_add(sfd, select.SELECT_READ) + s:wait() + local newfd = assert(sfd:accept()) + assert(as:associate(newfd)) + return newfd +end + +local function wait_completion(as, timeout) + timeout = timeout or 1000 + local start = time.monotonic() + while time.monotonic() - start < timeout do + for op, token, st, data, errcode in as:wait(100) do + return op, token, st, data, errcode + end + end + lt.failure("wait_completion timeout") +end + +--- 测试创建和基本属性 +function m.test_create() + lt.assertFailed("max_completions is less than or equal to zero.", async.create(-1)) + lt.assertFailed("max_completions is less than or equal to zero.", async.create(0)) + local as = async.create(64) + lt.assertIsUserdata(as) +end + +--- 测试枚举值 +function m.test_enum() + lt.assertIsNumber(SUCCESS) + lt.assertIsNumber(CLOSE) + lt.assertIsNumber(ERROR) + lt.assertIsNumber(CANCEL) + -- 枚举值应该是不同的 + lt.assertEquals(SUCCESS ~= CLOSE, true) + lt.assertEquals(SUCCESS ~= ERROR, true) + lt.assertEquals(SUCCESS ~= CANCEL, true) + lt.assertEquals(CLOSE ~= ERROR, true) + lt.assertEquals(CLOSE ~= CANCEL, true) + lt.assertEquals(ERROR ~= CANCEL, true) +end + +--- 测试 poll 在没有完成事件时返回空 +function 
m.test_poll_empty() + local as = assert(async.create(64)) + local count = 0 + for _ in as:poll() do + count = count + 1 + end + lt.assertEquals(count, 0) +end + +--- 测试 wait 超时 +function m.test_wait_timeout() + local as = assert(async.create(64)) + local start = time.monotonic() + local count = 0 + for _ in as:wait(100) do + count = count + 1 + end + local elapsed = time.monotonic() - start + lt.assertEquals(count, 0) + -- 只验证确实等待了接近 timeout,避免依赖易抖动的严格上界 + lt.assertEquals(elapsed >= 50, true) +end + +--- 测试 TCP write 和 read,token 为字符串 +function m.test_tcp_write_read() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + local wb = assert(async.writebuf(64 * 1024)) + wb:write("hello") + lt.assertEquals(as:submit_write(wb, cfd, "write_token"), true) + local op, token, status, bytes = wait_completion(as) + lt.assertEquals(op, async.OP_WRITEV) + lt.assertEquals(token, "write_token") + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, 0) -- writebuf completion always returns 0 bytes + + -- 提交流式读操作,token 为 table + local read_token = { id = 42 } + local rb = assert(async.readbuf(64)) + lt.assertEquals(as:submit_read(rb, newfd, read_token), true) + op, token, status, bytes = wait_completion(as) + lt.assertEquals(op, async.OP_READ) + lt.assertEquals(token, read_token) + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, 5) + -- 从 ring buffer 精确读取 + lt.assertEquals(rb:read(5), "hello") + + newfd:close() +end + +--- 测试 TCP accept,token 为 table +function m.test_tcp_accept() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local _, port = sfd:info "socket":value() + + local accept_token = { op = "accept" } + lt.assertEquals(as:submit_accept(sfd, accept_token), true) + + -- 连接到服务器以触发 accept + local cfd = SimpleClient(as, "tcp", "127.0.0.1", port) + + local _, token, status, newfd = 
wait_completion(as) + lt.assertEquals(token, accept_token) + lt.assertEquals(status, SUCCESS) + -- accept 完成后返回新的 socket userdata + lt.assertIsUserdata(newfd) + + -- 关闭 accepted fd + newfd:close() +end + +--- 测试 TCP connect,token 为数字 +function m.test_tcp_connect() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local _, port = sfd:info "socket":value() + + local cfd = assert(socket.create "tcp") + assert(as:associate(cfd)) + lt.assertEquals(as:submit_connect(cfd, "127.0.0.1", port, 99), true) + + local _, token, status = wait_completion(as) + lt.assertEquals(token, 99) + lt.assertEquals(status, SUCCESS) +end + +--- 测试文件读写 +function m.test_file_read_write() + local as = assert(async.create(64)) + local fs = require "bee.filesystem" + local filepath = fs.current_path() / "test_async_file.txt" + + -- 创建测试文件 + do + local f = assert(io.open(filepath:string(), "wb")) + f:write("hello async file io") + f:close() + end + + -- 打开文件用于读取,关联到异步实例 + local rf = assert(io.open(filepath:string(), "rb")) + assert(as:associate_file(rf)) + + -- 提交文件读操作,token 为字符串 + lt.assertEquals(as:submit_file_read(rf, 128, 0, "fread"), true) + local op, token, status, bytes = wait_completion(as) + lt.assertEquals(op, async.OP_FILE_READ) + lt.assertEquals(token, "fread") + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, "hello async file io") + + rf:close() + + -- 打开文件用于写入,关联到异步实例 + local wf = assert(io.open(filepath:string(), "wb")) + assert(as:associate_file(wf)) + lt.assertEquals(as:submit_file_write(wf, "written by async", 0, "fwrite"), true) + op, token, status, bytes = wait_completion(as) + lt.assertEquals(op, async.OP_FILE_WRITE) + lt.assertEquals(token, "fwrite") + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, 16) -- "written by async" 长度为 16 + + wf:close() + + -- 验证写入内容 + local fr = assert(io.open(filepath:string(), "rb")) + lt.assertEquals(fr:read "*a", "written by async") + fr:close() + + fs.remove(filepath) +end + +--- 
测试文件偏移量读写 +function m.test_file_read_write_offset() + local as = assert(async.create(64)) + local fs = require "bee.filesystem" + local filepath = fs.current_path() / "test_async_offset.txt" + + -- 创建测试文件 + do + local f = assert(io.open(filepath:string(), "wb")) + f:write("0123456789ABCDEF") + f:close() + end + + -- 从偏移量 10 读取,token 为数字 + local rf = assert(io.open(filepath:string(), "rb")) + assert(as:associate_file(rf)) + + lt.assertEquals(as:submit_file_read(rf, 6, 10, 1), true) + local _, token, status, bytes = wait_completion(as) + lt.assertEquals(token, 1) + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, "ABCDEF") + + rf:close() + fs.remove(filepath) +end + +--- 测试读取已关闭的连接 +function m.test_read_closed() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + + -- 关闭客户端 + cfd:close() + + -- 在服务端提交读操作应该收到 close 状态 + local rb = assert(async.readbuf(64)) + local read_token = { op = "read_closed" } + lt.assertEquals(as:submit_read(rb, newfd, read_token), true) + local _, token, status = wait_completion(as) + lt.assertEquals(token, read_token) + lt.assertEquals(status, CLOSE) + + newfd:close() +end + +--- 测试多个并发请求,token 为循环变量(数字) +function m.test_multiple_requests() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local _, port = sfd:info "socket":value() + + -- 创建多个客户端连接 + local clients = {} + local servers = {} + for i = 1, 3 do + clients[i] = SimpleClient(as, "tcp", "127.0.0.1", port) + servers[i] = wait_accept(as, sfd) + end + + -- 提交多个写操作,token 为数字 + local wbs = {} + for i = 1, 3 do + wbs[i] = assert(async.writebuf(64 * 1024)) + wbs[i]:write("msg"..i) + lt.assertEquals(as:submit_write(wbs[i], clients[i], i), true) + end + + -- 收集所有完成事件 + local results = {} + local deadline = time.monotonic() + 1000 + while #results < 3 and time.monotonic() < deadline do + for op, token, 
status, bytes in as:wait(100) do + results[#results+1] = { op = op, token = token, status = status, bytes = bytes } + end + end + lt.assertEquals(#results, 3) + + -- 验证所有请求都成功 + for _, r in ipairs(results) do + lt.assertEquals(r.status, SUCCESS) + lt.assertEquals(r.bytes, 0) -- writebuf completion always returns 0 bytes + end + + for i = 1, 3 do + clients[i]:close() + servers[i]:close() + end +end + +--- 测试 readbuf 创建和 ring buffer 基本操作 +function m.test_readbuf() + -- 无效参数 + lt.assertErrorMsgEquals("bufsize must be positive", async.readbuf, 0) + lt.assertErrorMsgEquals("bufsize must be positive", async.readbuf, -1) + + local rb = assert(async.readbuf(64)) + lt.assertIsUserdata(rb) + + -- 空 ring buffer:read 返回 nil + lt.assertEquals(rb:read(1), nil) + lt.assertEquals(rb:read(), nil) +end + +--- 测试 submit_read + rb:read 流式接收 +function m.test_stream_read() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + local rb = assert(async.readbuf(256)) + + -- 发送两段数据 + local wb = assert(async.writebuf(64 * 1024)) + wb:write("helloworld") + lt.assertEquals(as:submit_write(wb, cfd, "w1"), true) + wait_completion(as) -- write done + + -- 投递 stream read,token 为字符串 + lt.assertEquals(as:submit_read(rb, newfd, "r1"), true) + local _, token, status, bytes = wait_completion(as) + lt.assertEquals(token, "r1") + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes >= 1, true) + + -- rb:read(n) 不足时返回 nil + local total = rb:read() -- 取全部 + lt.assertEquals(total ~= nil, true) + lt.assertEquals(#total >= 1, true) + + -- 再取时为空 + lt.assertEquals(rb:read(), nil) + + newfd:close() +end + +--- 测试 rb:readline +function m.test_readline() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + local rb = assert(async.readbuf(256)) + + 
local function recv() + lt.assertEquals(as:submit_read(rb, newfd, "recv"), true) + local _, _, status = wait_completion(as) + lt.assertEquals(status, SUCCESS) + end + + local function send(data, token) + local w = assert(async.writebuf(64 * 1024)) + w:write(data) + lt.assertEquals(as:submit_write(w, cfd, token), true) + wait_completion(as) + end + + -- 默认分隔符 \r\n + send("hello\r\nworld\r\n", "s1") + recv() + lt.assertEquals(rb:readline(), "hello\r\n") + lt.assertEquals(rb:readline(), "world\r\n") + lt.assertEquals(rb:readline(), nil) + + -- 自定义分隔符 \n + send("foo\nbar\n", "s2") + recv() + lt.assertEquals(rb:readline("\n"), "foo\n") + lt.assertEquals(rb:readline("\n"), "bar\n") + lt.assertEquals(rb:readline("\n"), nil) + + -- 多字节自定义分隔符 + send("a|b|c|b|", "s3") + recv() + lt.assertEquals(rb:readline("|b|"), "a|b|") + lt.assertEquals(rb:readline("|b|"), "c|b|") + lt.assertEquals(rb:readline("|b|"), nil) + + -- 不完整行返回 nil,read() 仍可取出 + send("no-sep", "s4") + recv() + lt.assertEquals(rb:readline(), nil) + lt.assertEquals(rb:read(), "no-sep") + + newfd:close() +end + +if platform.os == "windows" then + --- 测试 associate 和 cancel(仅 Windows) + function m.test_associate_cancel() + local as = assert(async.create(64)) + local sfd = assert(socket.create("tcp")) + -- associate 应该成功 + lt.assertEquals(as:associate(sfd), true) + -- 重复 associate 同一个 socket 也应该成功(已关联) + lt.assertEquals(as:associate(sfd), true) + + assert(sfd:bind("127.0.0.1", 0)) + assert(sfd:listen()) + local _, port = sfd:info "socket":value() + + local cfd = assert(socket.create "tcp") + lt.assertEquals(as:associate(cfd), true) + lt.assertEquals(as:submit_connect(cfd, "127.0.0.1", port, "connect_tok"), true) + + -- cancel 不应该崩溃 + as:cancel(cfd) + + -- 等待 connect 完成(可能是成功或取消) + wait_completion(as) + end +end + +--- 测试 writebuf 多段发送 +function m.test_writebuf() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local 
newfd = wait_accept(as, sfd) + + -- 多段写入同一个 writebuf,单次 submit_write 发完,token 为 table + local wb = assert(async.writebuf(64 * 1024)) + local parts = { "hel", "lo,", " wo", "rld" } + local expected = table.concat(parts) + for _, s in ipairs(parts) do wb:write(s) end + local write_tok = { id = "wb_tok" } + lt.assertEquals(as:submit_write(wb, cfd, write_tok), true) + local _, token, status = wait_completion(as) + lt.assertEquals(token, write_tok) + lt.assertEquals(status, SUCCESS) + + -- 验证数据完整性:可能分多次收到(如 BSD/kqueue 逐段发送) + local rb = assert(async.readbuf(64)) + local received = 0 + while received < #expected do + lt.assertEquals(as:submit_read(rb, newfd, "read_tok"), true) + local _, _, rstatus, rbytes = wait_completion(as) + lt.assertEquals(rstatus, SUCCESS) + received = received + rbytes + end + lt.assertEquals(rb:read(#expected), expected) + + -- wb:buffered() + local wb2 = assert(async.writebuf(64 * 1024)) + lt.assertEquals(wb2:buffered(), 0) + wb2:write("ping") + lt.assertEquals(wb2:buffered(), 4) + + -- wb:write 背压:hwm=4, 写入4字节应返回 true + local wb3 = assert(async.writebuf(4)) + lt.assertEquals(wb3:write("abcd"), true) -- buffered(4) >= hwm(4) + lt.assertEquals(wb3:write("x"), true) -- still >= hwm + + newfd:close() +end + +--- 测试 submit_poll:只监听 fd 可读性,不消费数据,token 为字符串 +function m.test_submit_poll() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + + -- 对服务端 fd 提交 poll 请求 + local poll_tok = "poll" + lt.assertEquals(as:submit_poll(newfd, poll_tok), true) + + -- 客户端发送数据,触发服务端 fd 可读 + local wb = assert(async.writebuf(64 * 1024)) + wb:write("poll_test") + local write_tok = "write" + lt.assertEquals(as:submit_write(wb, cfd, write_tok), true) + + -- 收集 write 和 poll 两个 completions,顺序不确定 + local results = {} + local timeout = 1000 + local start = time.monotonic() + while #results < 2 and time.monotonic() - start < timeout do + for op, 
token, status, bytes, errcode in as:wait(100) do + results[#results+1] = { op = op, token = token, status = status, bytes = bytes, errcode = errcode } + end + end + lt.assertEquals(#results, 2) + + -- 按 token 找到 poll completion + local poll_result + for _, r in ipairs(results) do + if r.token == poll_tok then + poll_result = r + break + end + end + lt.assertIsTable(poll_result) + lt.assertEquals(poll_result.status, SUCCESS) + lt.assertEquals(poll_result.bytes, 0) -- poll 不消费数据,bytes 为 0 + lt.assertEquals(poll_result.errcode, 0) + + -- 验证数据未被消费:仍然可以正常读取 + local rb = assert(async.readbuf(64)) + local read_tok = "read" + lt.assertEquals(as:submit_read(rb, newfd, read_tok), true) + local _, token, status, bytes = wait_completion(as) + lt.assertEquals(token, read_tok) + lt.assertEquals(status, SUCCESS) + lt.assertEquals(bytes, 9) -- "poll_test" 长度为 9 + lt.assertEquals(rb:read(9), "poll_test") + + newfd:close() +end + +--- 测试 submit_poll 配合 channel fd +function m.test_submit_poll_channel() + local channel = require "bee.channel" + channel.create "poll_test_ch" + local ch = channel.query "poll_test_ch" + + local as = assert(async.create(64)) + assert(as:associate(ch:fd())) + + -- 提交 poll 请求,token 为 table + local poll_tok = { ch = "poll_test_ch" } + lt.assertEquals(as:submit_poll(ch:fd(), poll_tok), true) + + -- 向 channel push 数据,触发 fd 可读 + ch:push("hello", 42) + + -- 等待 poll completion + local got_op, got_tok, got_status, got_bytes, got_errcode + local deadline2 = time.monotonic() + 1000 + while not got_op and time.monotonic() < deadline2 do + for op, token, status, bytes, errcode in as:wait(100) do + got_op = op + got_tok = token + got_status = status + got_bytes = bytes + got_errcode = errcode + end + end + lt.assertEquals(got_tok, poll_tok) + lt.assertEquals(got_status, SUCCESS) + lt.assertEquals(got_bytes, 0) + lt.assertEquals(got_errcode, 0) + + -- 验证数据未被消费:channel pop 仍然可以取到数据 + local ok, msg, num = ch:pop() + lt.assertEquals(ok, true) + lt.assertEquals(msg, 
"hello") + lt.assertEquals(num, 42) + + channel.destroy "poll_test_ch" +end + +--- 测试 stop +function m.test_stop() + local as = assert(async.create(64)) + lt.assertEquals(as:stop(), true) +end + +--- 测试 close 标记(__close 元方法) +function m.test_close() + do + local as <close> = async.create(64) + -- 离开作用域时应该自动调用 stop + end + -- 没有崩溃即通过 +end + +--- 测试 submit_readv:ring_buf 回绕时一次提交两段 +-- 场景:先收满 (cap-2) 字节使 tail 接近末尾,消费后再用 submit_readv +-- 此时 write_ptr 指向末尾偏移 (cap-2),write_len=2,free_cap=cap; +-- submit_read 会构造两段 iobuf:第一段 2 字节到缓冲区末尾, +-- 第二段 (cap-2) 字节回绕到缓冲区头部,覆盖全部空闲空间。 +function m.test_readv_wraparound() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + + -- cap=16,先收 14 字节(cap-2)使 tail=14 + -- 先提交读再发送,确保不与后续 send 合并到同一个 TCP segment。 + local rb = assert(async.readbuf(16)) + lt.assertEquals(as:submit_read(rb, newfd, "fill"), true) + local fill_data = string.rep("x", 14) + cfd:send(fill_data) + local _, _, s1, b1 = wait_completion(as) + lt.assertEquals(s1, SUCCESS) + -- 可能分批到达,循环直到收到全部 14 字节 + local got = b1 + while got < 14 do + lt.assertEquals(as:submit_read(rb, newfd, "fill"), true) + local _, _, s2, b2 = wait_completion(as) + lt.assertEquals(s2, SUCCESS) + got = got + b2 + end + -- 消费掉全部数据:head 追上 tail(=14),缓冲区清空 + rb:read(14) + + -- 现在 tail=14, head=14,write_ptr 在偏移 14(=14 & 15),write_len=2,free_cap=16 + -- submit_read 会用两段:[offset14..15] + [offset0..13] + -- 先提交读再发送,确保 submit_read 时 ring_buf 处于回绕状态 + local payload = string.rep("A", 2) .. 
string.rep("B", 14) -- 16 字节,填满两段 + lt.assertEquals(as:submit_read(rb, newfd, "readv_tok"), true) + cfd:send(payload) + local op, tok, rs, rb2 = wait_completion(as) + lt.assertEquals(tok, "readv_tok") + lt.assertEquals(rs, SUCCESS) + lt.assertEquals(op, async.OP_READV) + -- 收到的字节数应等于发送量(16 字节 ≤ free_cap=16) + lt.assertEquals(rb2 > 0, true) + + -- 数据完整性:循环读取直到收到全部 16 字节 + local received = rb2 + while received < 16 do + lt.assertEquals(as:submit_read(rb, newfd, "readv_tok2"), true) + local _, _, rs2, rb3 = wait_completion(as) + lt.assertEquals(rs2, SUCCESS) + received = received + rb3 + end + lt.assertEquals(rb:read(16), payload) + + newfd:close() +end + +--- 测试 writev 批量提交:多个 entry 一次 submit_write,只产生一次 completion +function m.test_writev_batch() + local as = assert(async.create(64)) + local sfd = SimpleServer(as, "tcp", "127.0.0.1", 0) + local cfd = SimpleClient(as, "tcp", sfd:info "socket") + local newfd = wait_accept(as, sfd) + + -- 多个小 entry 写入同一个 writebuf,一次提交应合并为单次 OP_WRITEV completion + local wb = assert(async.writebuf(64 * 1024)) + local parts = { "aaa", "bbb", "ccc", "ddd", "eee" } + local expected = table.concat(parts) + for _, s in ipairs(parts) do wb:write(s) end + + local write_tok = "batch_write" + lt.assertEquals(as:submit_write(wb, cfd, write_tok), true) + local op, token, wstatus = wait_completion(as) + lt.assertEquals(token, write_tok) + lt.assertEquals(wstatus, SUCCESS) + lt.assertEquals(op, async.OP_WRITEV) + + -- 验证数据完整性:循环读取直到收满 #expected 字节 + local rb = assert(async.readbuf(64)) + local received = 0 + while received < #expected do + lt.assertEquals(as:submit_read(rb, newfd, "r"), true) + local _, _, rstatus, rbytes = wait_completion(as) + lt.assertEquals(rstatus, SUCCESS) + received = received + rbytes + end + lt.assertEquals(rb:read(#expected), expected) + + newfd:close() +end