diff --git a/CMakeLists.txt b/CMakeLists.txt index 281e6f8..982ff26 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,7 +183,7 @@ list(APPEND sources "${CMAKE_CURRENT_BINARY_DIR}/git_revision.cpp") list(APPEND sources ${fc_headers}) add_subdirectory( vendor/easylzma ) -#add_subdirectory( vendor/udt4 ) +add_subdirectory( vendor/udt4 ) setup_library( fc SOURCES ${sources} LIBRARY_TYPE STATIC DONT_INSTALL_LIBRARY ) @@ -228,11 +228,11 @@ target_link_libraries( ntp_test fc ) #include_directories( vendor/udt4/src ) -#add_executable( udt_server tests/udt_server.cpp ) -#target_link_libraries( udt_server fc udt ) +add_executable( udt_server tests/udts.cpp ) +target_link_libraries( udt_server fc udt ) -#add_executable( udt_client tests/udt_client.cpp ) -#target_link_libraries( udt_client fc udt ) +add_executable( udt_client tests/udtc.cpp ) +target_link_libraries( udt_client fc udt ) #add_executable( test_compress tests/compress.cpp ) #target_link_libraries( test_compress fc ) diff --git a/include/fc/network/udt_socket.hpp b/include/fc/network/udt_socket.hpp index b022b93..aa1dd49 100644 --- a/include/fc/network/udt_socket.hpp +++ b/include/fc/network/udt_socket.hpp @@ -3,17 +3,19 @@ #include #include #include +#include namespace fc { namespace ip { class endpoint; } - class udt_socket : public virtual iostream + class udt_socket : public virtual iostream, public noncopyable { public: udt_socket(); ~udt_socket(); - void connect_to( const fc::ip::endpoint& remote_endpoint ); + void bind( const fc::ip::endpoint& local_endpoint ); + void connect_to( const fc::ip::endpoint& remote_endpoint ); fc::ip::endpoint remote_endpoint() const; fc::ip::endpoint local_endpoint() const; @@ -46,4 +48,20 @@ namespace fc { }; typedef std::shared_ptr udt_socket_ptr; + class udt_server : public noncopyable + { + public: + udt_server(); + ~udt_server(); + + void close(); + void accept( udt_socket& s ); + + void listen( const fc::ip::endpoint& ep ); + fc::ip::endpoint local_endpoint() const; + + private: + int _udt_socket_id; + }; + } // fc diff --git a/include/fc/noncopyable.hpp b/include/fc/noncopyable.hpp new file mode 100644 index 0000000..87fad6b --- /dev/null +++ b/include/fc/noncopyable.hpp @@ -0,0 +1,14 @@ +#pragma once + +namespace fc +{ + class noncopyable + { + public: + noncopyable(){} + private: + noncopyable( const noncopyable& ) = delete; + noncopyable& operator=( const noncopyable& ) = delete; + }; +} + diff --git a/src/network/udt_socket.cpp b/src/network/udt_socket.cpp index 7a1f197..3d23073 100644 --- a/src/network/udt_socket.cpp +++ b/src/network/udt_socket.cpp @@ -8,6 +8,7 @@ #include namespace fc { + class udt_epoll_service { @@ -33,7 +34,7 @@ namespace fc { UDT::epoll_wait( _epoll_id, &read_ready, - &write_ready, 1000 ); + &write_ready, 1000*1000 ); { synchronized(_read_promises_mutex) for( auto sock : read_ready ) @@ -101,6 +102,12 @@ namespace fc { }; + udt_epoll_service& default_epool_service() + { + static udt_epoll_service* default_service = new udt_epoll_service(); + return *default_service; + } + void check_udt_errors() { @@ -120,18 +127,40 @@ namespace fc { udt_socket::~udt_socket() { - close(); + try { + close(); + } catch ( const fc::exception& e ) + { + wlog( "${e}", ("e", e.to_detail_string() ) ); + } } + void udt_socket::bind( const fc::ip::endpoint& local_endpoint ) + { try { + if( !is_open() ) + open(); + + sockaddr_in local_addr; + local_addr.sin_family = AF_INET; + local_addr.sin_port = htons(local_endpoint.port()); + local_addr.sin_addr.s_addr = htonl(local_endpoint.get_address()); + + if( UDT::ERROR == UDT::bind(_udt_socket_id, (sockaddr*)&local_addr, sizeof(local_addr)) ) + check_udt_errors(); + } FC_CAPTURE_AND_RETHROW() } + void udt_socket::connect_to( const ip::endpoint& remote_endpoint ) { try { + if( !is_open() ) + open(); + sockaddr_in serv_addr; serv_addr.sin_family = AF_INET; serv_addr.sin_port = htons(remote_endpoint.port()); serv_addr.sin_addr.s_addr = htonl(remote_endpoint.get_address()); // connect to the server, implict bind - if (UDT::ERROR == UDT::connect(_udt_socket_id, (sockaddr*)&serv_addr, sizeof(serv_addr))) + if( UDT::ERROR == UDT::connect(_udt_socket_id, (sockaddr*)&serv_addr, sizeof(serv_addr)) ) check_udt_errors(); } FC_CAPTURE_AND_RETHROW( (remote_endpoint) ) } @@ -165,8 +194,11 @@ namespace fc { { if( UDT::getlasterror().getErrorCode() == CUDTException::EASYNCRCV ) { - // create a future and post to epoll, wait on it, then - // call readsome recursively. + UDT::getlasterror().clear(); + promise::ptr p(new promise("udt_socket::readsome")); + default_epool_service().notify_read( _udt_socket_id, p ); + p->wait(); + return readsome( buffer, max ); } else check_udt_errors(); @@ -188,7 +220,18 @@ namespace fc { auto bytes_sent = UDT::send(_udt_socket_id, buffer, len, 0); if( UDT::ERROR == bytes_sent ) - check_udt_errors(); + { + if( UDT::getlasterror().getErrorCode() == CUDTException::EASYNCRCV ) + { + UDT::getlasterror().clear(); + promise::ptr p(new promise("udt_socket::writesome")); + default_epool_service().notify_write( _udt_socket_id, p ); + p->wait(); + return writesome( buffer, len ); + } + else + check_udt_errors(); + } if( bytes_sent == 0 ) { @@ -217,4 +260,81 @@ namespace fc { } + + + + + udt_server::udt_server() + :_udt_socket_id( UDT::INVALID_SOCK ) + { + _udt_socket_id = UDT::socket(AF_INET, SOCK_STREAM, 0); + bool block = false; + UDT::setsockopt(_udt_socket_id, 0, UDT_SNDSYN, &block, sizeof(bool)); + UDT::setsockopt(_udt_socket_id, 0, UDT_RCVSYN, &block, sizeof(bool)); + } + + udt_server::~udt_server() + { + try { + close(); + } catch ( const fc::exception& e ) + { + wlog( "${e}", ("e", e.to_detail_string() ) ); + } + } + + void udt_server::close() + { try { + UDT::close( _udt_socket_id ); + check_udt_errors(); + } FC_CAPTURE_AND_RETHROW() } + + void udt_server::accept( udt_socket& s ) + { try { + FC_ASSERT( !s.is_open() ); + int namelen; + sockaddr_in their_addr; + + s._udt_socket_id = UDT::accept( _udt_socket_id, (sockaddr*)&their_addr, &namelen ); + + if( s._udt_socket_id == UDT::INVALID_SOCK ) + { + if( UDT::getlasterror().getErrorCode() == CUDTException::EASYNCRCV ) + { + UDT::getlasterror().clear(); + promise::ptr p(new promise("udt_server::accept")); + default_epool_service().notify_read( _udt_socket_id, p ); + p->wait(); + this->accept(s); + } + else + check_udt_errors(); + } + } FC_CAPTURE_AND_RETHROW() } + + void udt_server::listen( const ip::endpoint& ep ) + { try { + sockaddr_in my_addr; + my_addr.sin_family = AF_INET; + my_addr.sin_port = htons(ep.port()); + my_addr.sin_addr.s_addr = INADDR_ANY; + memset(&(my_addr.sin_zero), '\0', 8); + + if( UDT::ERROR == UDT::bind(_udt_socket_id, (sockaddr*)&my_addr, sizeof(my_addr)) ) + check_udt_errors(); + + UDT::listen(_udt_socket_id, 10); + check_udt_errors(); + } FC_CAPTURE_AND_RETHROW( (ep) ) } + + fc::ip::endpoint udt_server::local_endpoint() const + { try { + sockaddr_in sock_addr; + int addr_size = sizeof(sock_addr); + int error_code = UDT::getsockname( _udt_socket_id, (struct sockaddr*)&sock_addr, &addr_size ); + if( error_code == UDT::ERROR ) + check_udt_errors(); + return ip::endpoint( ip::address( htonl( sock_addr.sin_addr.s_addr ) ), htons(sock_addr.sin_port) ); + } FC_CAPTURE_AND_RETHROW() } + } diff --git a/tests/udtc.cpp b/tests/udtc.cpp new file mode 100644 index 0000000..adfd497 --- /dev/null +++ b/tests/udtc.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#include + +using namespace fc; + +int main( int argc, char** argv ) +{ + try { + udt_socket sock; + sock.bind( fc::ip::endpoint::from_string( "127.0.0.1:6666" ) ); + sock.connect_to( fc::ip::endpoint::from_string( "127.0.0.1:7777" ) ); + + std::cout << "local endpoint: " < response; + response.resize(1024); + int r = sock.readsome( response.data(), response.size() ); + + std::cout << "response: '"< +#include +#include +#include +#include + +using namespace fc; + +int main( int argc, char** argv ) +{ + try { + udt_server serv; + serv.listen( fc::ip::endpoint::from_string( "127.0.0.1:7777" ) ); + + while( true ) + { + udt_socket sock; + serv.accept( sock ); + + std::vector response; + response.resize(1024); + int r = sock.readsome( response.data(), response.size() ); + + std::cout << "request: '"<