33#ifdef _WIN32
44#include < winsock2.h> // closesocket, recv, send, socket
55#else
6+ #include < fcntl.h> // fcntl, F_GETFL, F_SETFL, O_NONBLOCK
67#include < netdb.h> // addrinfo, gai_strerror, getaddrinfo, freeaddrinfo
78#include < netinet/in.h> // htons, IPPROTO_TCP
89#include < sys/poll.h> // pollfd
9- #include < sys/socket.h> // AF_INET, connect, recv, send, sockaddr, sockaddr_in, socket, SOCK_STREAM
10- #include < unistd.h> // close, ssize_t
10+ #include < sys/socket.h> // AF_INET, connect, recv, send, sockaddr, sockaddr_in, socket, SOCK_STREAM, getsockopt, SO_ERROR, SOL_SOCKET
11+ #include < unistd.h> // close, ssize_t
1112
1213#include < cerrno> // errno
1314#endif
@@ -31,6 +32,55 @@ int GetErrNo() {
3132 return errno;
3233#endif
3334}
35+
36+ int Poll (::pollfd* fds, std::uint32_t nfds, int timeout_ms) {
37+ #ifdef _WIN32
38+ return ::WSAPoll (fds, nfds, timeout_ms);
39+ #else
40+ return ::poll (fds, static_cast <::nfds_t >(nfds), timeout_ms);
41+ #endif
42+ }
43+
44+ #ifdef _WIN32
45+ constexpr int kConnectInProgress = WSAEWOULDBLOCK;
46+ #else
47+ constexpr int kConnectInProgress = EINPROGRESS;
48+ #endif
49+
50+ // Saves the current blocking state, sets non-blocking, and returns a RAII guard
51+ // that restores the original state on destruction.
52+ struct BlockingGuard {
53+ databento::detail::Socket fd;
54+ #ifdef _WIN32
55+ // No state to save on Windows
56+ #else
57+ int original_flags;
58+ #endif
59+
60+ explicit BlockingGuard (databento::detail::Socket fd) : fd{fd} {
61+ #ifdef _WIN32
62+ unsigned long mode = 1 ;
63+ ::ioctlsocket (fd, FIONBIO, &mode);
64+ #else
65+ original_flags = ::fcntl (fd, F_GETFL, 0 );
66+ ::fcntl (fd, F_SETFL, original_flags | O_NONBLOCK);
67+ #endif
68+ }
69+
70+ ~BlockingGuard () {
71+ #ifdef _WIN32
72+ unsigned long mode = 0 ;
73+ ::ioctlsocket (fd, FIONBIO, &mode);
74+ #else
75+ ::fcntl (fd, F_SETFL, original_flags);
76+ #endif
77+ }
78+
79+ BlockingGuard (const BlockingGuard&) = delete ;
80+ BlockingGuard& operator =(const BlockingGuard&) = delete ;
81+ BlockingGuard (BlockingGuard&&) = delete ;
82+ BlockingGuard& operator =(BlockingGuard&&) = delete ;
83+ };
3484} // namespace
3585
3686TcpClient::TcpClient (ILogReceiver* log_receiver, const std::string& gateway,
@@ -83,12 +133,7 @@ databento::IReadable::Result TcpClient::ReadSome(std::byte* buffer,
83133 // having no timeout
84134 const auto timeout_ms = timeout.count () ? static_cast <int >(timeout.count ()) : -1 ;
85135 while (true ) {
86- const int poll_status =
87- #ifdef _WIN32
88- ::WSAPoll (&fds, 1 , timeout_ms);
89- #else
90- ::poll (&fds, 1 , timeout_ms);
91- #endif
136+ const int poll_status = Poll (&fds, 1 , timeout_ms);
92137 if (poll_status > 0 ) {
93138 return ReadSome (buffer, max_size);
94139 }
@@ -130,13 +175,34 @@ databento::detail::ScopedFd TcpClient::InitSocket(ILogReceiver* log_receiver,
130175 }
131176 std::unique_ptr<addrinfo, decltype (&::freeaddrinfo)> res{out, &::freeaddrinfo};
132177 const auto max_attempts = std::max<std::uint32_t >(retry_conf.max_attempts , 1 );
178+ const auto timeout_ms = static_cast <int >(
179+ std::chrono::duration_cast<std::chrono::milliseconds>(retry_conf.connect_timeout )
180+ .count ());
133181 std::chrono::seconds backoff{1 };
134182 for (std::uint32_t attempt = 0 ; attempt < max_attempts; ++attempt) {
135- if (::connect (scoped_fd.Get (), res->ai_addr , res->ai_addrlen ) == 0 ) {
183+ BlockingGuard guard{scoped_fd.Get ()};
184+
185+ const int connect_ret = ::connect (scoped_fd.Get (), res->ai_addr , res->ai_addrlen );
186+ bool connected = (connect_ret == 0 );
187+ if (!connected && ::GetErrNo () == kConnectInProgress ) {
188+ pollfd pfd{scoped_fd.Get (), POLLOUT, {}};
189+ const int poll_ret = Poll (&pfd, 1 , timeout_ms);
190+ if (poll_ret > 0 ) {
191+ int so_error = 0 ;
192+ socklen_t len = sizeof (so_error);
193+ ::getsockopt (scoped_fd.Get(), SOL_SOCKET, SO_ERROR, &so_error, &len);
194+ connected = (so_error == 0 );
195+ if (!connected) {
196+ errno = so_error;
197+ }
198+ }
199+ }
200+
201+ if (connected) {
136202 break ;
137203 } else if (attempt + 1 == max_attempts) {
138204 std::ostringstream err_msg;
139- err_msg << " Socket failed to connect after " << max_attempts << " attempts " ;
205+ err_msg << " Socket failed to connect after " << max_attempts << " attempt(s) " ;
140206 throw TcpError{::GetErrNo (), err_msg.str ()};
141207 }
142208 std::ostringstream log_msg;
0 commit comments