diff --git a/CMakeLists.txt b/CMakeLists.txt index a987ff2..0c6de49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -218,6 +218,8 @@ target_link_libraries( fc PUBLIC easylzma_static ${Boost_LIBRARIES} ${OPENSSL_LI #target_link_libraries( test_aes fc ${rt_library} ${pthread_library} ) #add_executable( test_sleep tests/sleep.cpp ) #target_link_libraries( test_sleep fc ) +add_executable( test_rate_limiting tests/rate_limiting.cpp ) +target_link_libraries( test_rate_limiting fc ) if(WIN32) # add addtional import library on windows platform diff --git a/include/fc/network/rate_limiting.hpp b/include/fc/network/rate_limiting.hpp index 10ac9bb..7694789 100644 --- a/include/fc/network/rate_limiting.hpp +++ b/include/fc/network/rate_limiting.hpp @@ -1,8 +1,7 @@ #pragma once -#include -#include -#include -#include +#include + +#include namespace fc { @@ -11,12 +10,22 @@ namespace fc class rate_limiting_group_impl; } + class tcp_socket; + class rate_limiting_group { public: rate_limiting_group(uint32_t upload_bytes_per_second, uint32_t download_bytes_per_second); ~rate_limiting_group(); + void set_upload_limit(uint32_t upload_bytes_per_second); + uint32_t get_upload_limit() const; + + void set_download_limit(uint32_t download_bytes_per_second); + uint32_t get_download_limit() const; + + void add_tcp_socket(tcp_socket* tcp_socket_to_limit); + void remove_tcp_socket(tcp_socket* tcp_socket_to_stop_limiting); private: std::unique_ptr my; }; diff --git a/include/fc/network/tcp_socket.hpp b/include/fc/network/tcp_socket.hpp index 3229f87..70fb9d8 100644 --- a/include/fc/network/tcp_socket.hpp +++ b/include/fc/network/tcp_socket.hpp @@ -6,6 +6,9 @@ namespace fc { namespace ip { class endpoint; } + + class tcp_socket_io_hooks; + class tcp_socket : public virtual iostream { public: @@ -15,6 +18,7 @@ namespace fc { void connect_to( const fc::ip::endpoint& remote_endpoint ); void connect_to( const fc::ip::endpoint& remote_endpoint, const fc::ip::endpoint& local_endpoint ); void enable_keep_alives(const fc::microseconds& interval); + void set_io_hooks(tcp_socket_io_hooks* new_hooks); fc::ip::endpoint remote_endpoint()const; void get( char& c ) @@ -41,9 +45,9 @@ namespace fc { friend class tcp_server; class impl; #ifdef _WIN64 - fc::fwd my; + fc::fwd my; #else - fc::fwd my; + fc::fwd my; #endif }; typedef std::shared_ptr tcp_socket_ptr; diff --git a/include/fc/network/tcp_socket_io_hooks.hpp b/include/fc/network/tcp_socket_io_hooks.hpp new file mode 100644 index 0000000..a317ed1 --- /dev/null +++ b/include/fc/network/tcp_socket_io_hooks.hpp @@ -0,0 +1,12 @@ +#include + +namespace fc +{ + class tcp_socket_io_hooks + { + public: + virtual ~tcp_socket_io_hooks() {} + virtual size_t readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length) = 0; + virtual size_t writesome(boost::asio::ip::tcp::socket& socket, const char* buffer, size_t length) = 0; + }; +} // namesapce fc diff --git a/src/network/rate_limiting.cpp b/src/network/rate_limiting.cpp index 6182169..f41eaa0 100644 --- a/src/network/rate_limiting.cpp +++ b/src/network/rate_limiting.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -6,6 +7,7 @@ #include #include #include +#include namespace fc { @@ -82,7 +84,7 @@ namespace fc } }; - class rate_limiting_group_impl + class rate_limiting_group_impl : public tcp_socket_io_hooks { public: uint32_t _upload_bytes_per_second; @@ -96,16 +98,22 @@ namespace fc rate_limited_operation_list _write_operations_in_progress; rate_limited_operation_list _write_operations_for_next_iteration; + time_point _last_read_iteration_time; time_point _last_write_iteration_time; + fc::future _process_pending_reads_loop_complete; + fc::future _process_pending_writes_loop_complete; + rate_limiting_group_impl(uint32_t upload_bytes_per_second, uint32_t download_bytes_per_second); - size_t readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length); - size_t writesome(boost::asio::ip::tcp::socket& socket, const char* buf, size_t len); + virtual size_t readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length) override; + virtual size_t writesome(boost::asio::ip::tcp::socket& socket, const char* buffer, size_t length) override; void process_pending_reads(); void process_pending_writes(); - void process_pending_operations(rate_limited_operation_list& operations_in_progress, + void process_pending_operations(time_point& last_iteration_start_time, + uint32_t& limit_bytes_per_second, + rate_limited_operation_list& operations_in_progress, rate_limited_operation_list& operations_for_next_iteration); }; @@ -118,9 +126,16 @@ namespace fc size_t rate_limiting_group_impl::readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length) { - promise::ptr completion_promise(new promise()); - _read_operations_for_next_iteration.emplace_back(std::make_unique(socket, buffer, length, completion_promise)); - return completion_promise->wait(); + if (_download_bytes_per_second) + { + promise::ptr completion_promise(new promise()); + _read_operations_for_next_iteration.emplace_back(std::make_unique(socket, buffer, length, completion_promise)); + if (!_process_pending_reads_loop_complete.valid()) + _process_pending_reads_loop_complete = async([=](){ process_pending_reads(); }); + return completion_promise->wait(); + } + else + return asio::read_some(socket, boost::asio::buffer(buffer, length)); } size_t rate_limiting_group_impl::writesome(boost::asio::ip::tcp::socket& socket, const char* buffer, size_t length) { @@ -128,6 +143,8 @@ namespace fc { promise::ptr completion_promise(new promise()); _write_operations_for_next_iteration.emplace_back(std::make_unique(socket, buffer, length, completion_promise)); + if (!_process_pending_writes_loop_complete.valid()) + _process_pending_writes_loop_complete = async([=](){ process_pending_writes(); }); return completion_promise->wait(); } else @@ -135,13 +152,25 @@ namespace fc } void rate_limiting_group_impl::process_pending_reads() { - process_pending_operations(_read_operations_in_progress, _read_operations_for_next_iteration); + for (;;) + { + process_pending_operations(_last_read_iteration_time, _download_bytes_per_second, + _read_operations_in_progress, _read_operations_for_next_iteration); + fc::usleep(_granularity); + } } void rate_limiting_group_impl::process_pending_writes() { - process_pending_operations(_write_operations_in_progress, _write_operations_for_next_iteration); + for (;;) + { + process_pending_operations(_last_write_iteration_time, _upload_bytes_per_second, + _write_operations_in_progress, _write_operations_for_next_iteration); + fc::usleep(_granularity); + } } - void rate_limiting_group_impl::process_pending_operations(rate_limited_operation_list& operations_in_progress, + void rate_limiting_group_impl::process_pending_operations(time_point& last_iteration_start_time, + uint32_t& limit_bytes_per_second, + rate_limited_operation_list& operations_in_progress, rate_limited_operation_list& operations_for_next_iteration) { // lock here for multithreaded @@ -150,39 +179,40 @@ namespace fc std::back_inserter(operations_in_progress)); operations_for_next_iteration.clear(); - // find out how much time since our last write - time_point this_write_iteration_start_time = time_point::now(); - if (_upload_bytes_per_second) // the we are limiting upload speed + // find out how much time since our last read/write + time_point this_iteration_start_time = time_point::now(); + if (limit_bytes_per_second) // the we are limiting up/download speed { - microseconds time_since_last_iteration = this_write_iteration_start_time - _last_write_iteration_time; + microseconds time_since_last_iteration = this_iteration_start_time - last_iteration_start_time; if (time_since_last_iteration > seconds(1)) time_since_last_iteration = seconds(1); else if (time_since_last_iteration < microseconds(0)) time_since_last_iteration = microseconds(0); - uint32_t total_bytes_for_this_iteration = - (uint32_t)(time_since_last_iteration.count() / _upload_bytes_per_second / seconds(1).count()); + uint32_t total_bytes_for_this_iteration = time_since_last_iteration.count() ? + (uint32_t)((1000000 * limit_bytes_per_second) / time_since_last_iteration.count()) : + 0; if (total_bytes_for_this_iteration) { - // sort the pending writes in order of the number of bytes they need to write, smallest first + // sort the pending reads/writes in order of the number of bytes they need to write, smallest first std::vector operations_sorted_by_length; operations_sorted_by_length.reserve(operations_in_progress.size()); for (std::unique_ptr& operation_data : operations_in_progress) operations_sorted_by_length.push_back(operation_data.get()); std::sort(operations_sorted_by_length.begin(), operations_sorted_by_length.end(), is_operation_shorter()); - // figure out how many bytes each writer is allowed to write + // figure out how many bytes each reader/writer is allowed to read/write uint32_t bytes_remaining_to_allocate = total_bytes_for_this_iteration; while (!operations_sorted_by_length.empty()) { - uint32_t bytes_permitted_for_this_writer = bytes_remaining_to_allocate / operations_sorted_by_length.size(); - uint32_t bytes_allocated_for_this_writer = std::min(operations_sorted_by_length.back()->length, bytes_permitted_for_this_writer); - operations_sorted_by_length.back()->permitted_length = bytes_allocated_for_this_writer; - bytes_remaining_to_allocate -= bytes_allocated_for_this_writer; + uint32_t bytes_permitted_for_this_operation = bytes_remaining_to_allocate / operations_sorted_by_length.size(); + uint32_t bytes_allocated_for_this_operation = std::min(operations_sorted_by_length.back()->length, bytes_permitted_for_this_operation); + operations_sorted_by_length.back()->permitted_length = bytes_allocated_for_this_operation; + bytes_remaining_to_allocate -= bytes_allocated_for_this_operation; operations_sorted_by_length.pop_back(); } - // kick off the writes in first-come order + // kick off the reads/writes in first-come order for (auto iter = operations_in_progress.begin(); iter != operations_in_progress.end();) { if ((*iter)->permitted_length > 0) @@ -209,7 +239,7 @@ namespace fc } operations_in_progress.clear(); } - _last_write_iteration_time = this_write_iteration_start_time; + last_iteration_start_time = this_iteration_start_time; } } @@ -223,6 +253,35 @@ namespace fc { } + void rate_limiting_group::set_upload_limit(uint32_t upload_bytes_per_second) + { + my->_upload_bytes_per_second = upload_bytes_per_second; + } + + uint32_t rate_limiting_group::get_upload_limit() const + { + return my->_upload_bytes_per_second; + } + + void rate_limiting_group::set_download_limit(uint32_t download_bytes_per_second) + { + my->_download_bytes_per_second = download_bytes_per_second; + } + + uint32_t rate_limiting_group::get_download_limit() const + { + return my->_download_bytes_per_second; + } + + void rate_limiting_group::add_tcp_socket(tcp_socket* tcp_socket_to_limit) + { + tcp_socket_to_limit->set_io_hooks(my.get()); + } + + void rate_limiting_group::remove_tcp_socket(tcp_socket* tcp_socket_to_stop_limiting) + { + tcp_socket_to_stop_limiting->set_io_hooks(NULL); + } } // namespace fc diff --git a/src/network/tcp_socket.cpp b/src/network/tcp_socket.cpp index 9212419..e588be1 100644 --- a/src/network/tcp_socket.cpp +++ b/src/network/tcp_socket.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -12,14 +13,32 @@ namespace fc { - class tcp_socket::impl { + class tcp_socket::impl : public tcp_socket_io_hooks{ public: - impl():_sock( fc::asio::default_io_service() ){} - ~impl(){ + impl() : + _sock(fc::asio::default_io_service()), + _io_hooks(this) + {} + ~impl() + { if( _sock.is_open() ) _sock.close(); } + virtual size_t readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length) override; + virtual size_t writesome(boost::asio::ip::tcp::socket& socket, const char* buffer, size_t length) override; + boost::asio::ip::tcp::socket _sock; + tcp_socket_io_hooks* _io_hooks; }; + + size_t tcp_socket::impl::readsome(boost::asio::ip::tcp::socket& socket, char* buffer, size_t length) + { + return fc::asio::read_some(socket, boost::asio::buffer(buffer, length)); + } + size_t tcp_socket::impl::writesome(boost::asio::ip::tcp::socket& socket, const char* buffer, size_t length) + { + return fc::asio::write_some(socket, boost::asio::buffer(buffer, length)); + } + bool tcp_socket::is_open()const { return my->_sock.is_open(); } @@ -42,8 +61,8 @@ namespace fc { return !my->_sock.is_open(); } - size_t tcp_socket::writesome( const char* buf, size_t len ) { - return fc::asio::write_some( my->_sock, boost::asio::buffer( buf, len ) ); + size_t tcp_socket::writesome(const char* buf, size_t len) { + return my->_io_hooks->writesome(my->_sock, buf, len); } fc::ip::endpoint tcp_socket::remote_endpoint()const @@ -52,9 +71,8 @@ namespace fc { return fc::ip::endpoint(rep.address().to_v4().to_ulong(), rep.port() ); } - size_t tcp_socket::readsome( char* buf, size_t len ) { - auto r = fc::asio::read_some( my->_sock, boost::asio::buffer( buf, len ) ); - return r; + size_t tcp_socket::readsome(char* buf, size_t len) { + return my->_io_hooks->readsome(my->_sock, buf, len); } void tcp_socket::connect_to( const fc::ip::endpoint& remote_endpoint ) { @@ -107,6 +125,11 @@ namespace fc { } } + void tcp_socket::set_io_hooks(tcp_socket_io_hooks* new_hooks) + { + my->_io_hooks = new_hooks ? new_hooks : &*my; + } + class tcp_server::impl { public: impl(uint16_t port) diff --git a/tests/rate_limiting.cpp b/tests/rate_limiting.cpp new file mode 100644 index 0000000..1ca31f3 --- /dev/null +++ b/tests/rate_limiting.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include +#include + +#include + +int main( int argc, char** argv ) +{ + fc::rate_limiting_group rate_limiter(1000000,1000000); + fc::http::connection http_connection; + rate_limiter.add_tcp_socket(&http_connection.get_socket()); + http_connection.connect_to(fc::ip::endpoint(fc::ip::address("162.243.115.24"),80)); + std::cout << "Starting download...\n"; + fc::time_point start_time(fc::time_point::now()); + fc::http::reply reply = http_connection.request("GET", "http://invictus.io/bin/Keyhotee-0.7.0.dmg"); + fc::time_point end_time(fc::time_point::now()); + + std::cout << "HTTP return code: " << reply.status << "\n"; + std::cout << "Retreived " << reply.body.size() << " bytes in " << ((end_time - start_time).count() / fc::milliseconds(1).count()) << "ms\n"; + std::cout << "Average speed " << ((1000 * (uint64_t)reply.body.size()) / ((end_time - start_time).count() / fc::milliseconds(1).count())) << " bytes per second"; + return 0; +}