diff --git a/include/fc/network/http/websocket.hpp b/include/fc/network/http/websocket.hpp index 5dd2a8d..0fd208b 100644 --- a/include/fc/network/http/websocket.hpp +++ b/include/fc/network/http/websocket.hpp @@ -91,9 +91,11 @@ namespace fc { namespace http { void close(); void synchronous_close(); + void append_header(const std::string& key, const std::string& value); private: std::unique_ptr my; std::unique_ptr smy; + std::vector> headers; }; class websocket_tls_client { diff --git a/src/log/logger.cpp b/src/log/logger.cpp index 93241f3..bd31afe 100644 --- a/src/log/logger.cpp +++ b/src/log/logger.cpp @@ -99,8 +99,12 @@ namespace fc { void logger::add_appender( const fc::shared_ptr& a ) { my->_appenders.push_back(a); } -// void logger::remove_appender( const fc::shared_ptr& a ) - // { my->_appenders.erase(a); } + void logger::remove_appender( const fc::shared_ptr& a ) + { + auto item = std::find(my->_appenders.begin(), my->_appenders.end(), a); + if (item != my->_appenders.end()) + my->_appenders.erase(item); + } std::vector > logger::get_appenders()const { diff --git a/src/network/http/websocket.cpp b/src/network/http/websocket.cpp index f315b7d..b09e4e3 100644 --- a/src/network/http/websocket.cpp +++ b/src/network/http/websocket.cpp @@ -165,9 +165,14 @@ namespace fc { namespace http { return _ws_connection->get_request_header(key); } + /**** + * @brief retrieves the remote hostname + * + * @param forward_header_key the key to look at in the request header + * @returns the value in the header, otherwise the remote endpoint + */ virtual std::string get_remote_hostname(const std::string& forward_header_key) { - // TODO: check headers, revert to the raw connection details if (!forward_header_key.empty()) { std::string header_value = _ws_connection->get_request_header(forward_header_key); @@ -186,17 +191,16 @@ namespace fc { namespace http { { public: websocket_server_impl(const std::string& forward_header_key = std::string() ) - :_server_thread( fc::thread::current() ) + :_server_thread( fc::thread::current() ), fwd_header_key(forward_header_key) { - _server.clear_access_channels( websocketpp::log::alevel::all ); _server.init_asio(&fc::asio::default_io_service()); _server.set_reuse_addr(true); _server.set_open_handler( [&]( connection_hdl hdl ){ - _server_thread.async( [&](){ - auto new_con = std::make_shared>( _server.get_con_from_hdl(hdl) ); - _on_connection( _connections[hdl] = new_con ); - }).wait(); + _server_thread.async( [&](){ + auto new_con = std::make_shared>( _server.get_con_from_hdl(hdl) ); + _on_connection( _connections[hdl] = new_con ); + }).wait(); }); _server.set_message_handler( [&]( connection_hdl hdl, websocket_server_type::message_ptr msg ){ _server_thread.async( [&](){ @@ -205,7 +209,7 @@ namespace fc { namespace http { auto payload = msg->get_payload(); std::shared_ptr con = current_con->second; wlog( "Websocket Server Remote: ${host} Payload: ${body}", - ("host", con->get_remote_hostname(forward_header_key)) ("body", msg->get_payload())); + ("host", con->get_remote_hostname(fwd_header_key)) ("body", msg->get_payload())); ++_pending_messages; auto f = fc::async([this,con,payload](){ if( _pending_messages ) --_pending_messages; con->on_message( payload ); }); if( _pending_messages > 100 ) @@ -297,6 +301,7 @@ namespace fc { namespace http { on_connection_handler _on_connection; fc::promise::ptr _closed; uint32_t _pending_messages = 0; + std::string fwd_header_key; }; class websocket_tls_server_impl @@ -422,7 +427,6 @@ namespace fc { namespace http { _client.set_message_handler( [&]( connection_hdl hdl, message_ptr msg ){ _client_thread.async( [&](){ wdump((msg->get_payload())); - //std::cerr<<"recv: "<get_payload()<<"\n"; auto received = msg->get_payload(); fc::async( [=](){ if( _connection ) @@ -658,7 +662,7 @@ namespace fc { namespace http { - websocket_client::websocket_client( const std::string& ca_filename ):my( new detail::websocket_client_impl() ),smy(new detail::websocket_tls_client_impl( ca_filename )) {} + websocket_client::websocket_client( const std::string& ca_filename):my( new detail::websocket_client_impl() ),smy(new detail::websocket_tls_client_impl( ca_filename )) {} websocket_client::~websocket_client(){ } websocket_connection_ptr websocket_client::connect( const std::string& uri ) @@ -667,7 +671,6 @@ namespace fc { namespace http { return secure_connect(uri); FC_ASSERT( uri.substr(0,3) == "ws:" ); - // wlog( "connecting to ${uri}", ("uri",uri)); websocketpp::lib::error_code ec; my->_uri = uri; @@ -683,6 +686,9 @@ namespace fc { namespace http { auto con = my->_client.get_connection( uri, ec ); + std::for_each(headers.begin(), headers.end(), [con](std::pair in) { + con->append_header(in.first, in.second); + }); if( ec ) FC_ASSERT( !ec, "error: ${e}", ("e",ec.message()) ); my->_client.connect(con); @@ -695,7 +701,6 @@ namespace fc { namespace http { if( uri.substr(0,3) == "ws:" ) return connect(uri); FC_ASSERT( uri.substr(0,4) == "wss:" ); - // wlog( "connecting to ${uri}", ("uri",uri)); websocketpp::lib::error_code ec; smy->_uri = uri; @@ -729,6 +734,11 @@ namespace fc { namespace http { my->_closed->wait(); } + void websocket_client::append_header(const std::string& key, const std::string& value) + { + headers.push_back( std::pair(key, value)); + } + websocket_connection_ptr websocket_tls_client::connect( const std::string& uri ) { try { // wlog( "connecting to ${uri}", ("uri",uri)); diff --git a/tests/network/http/websocket_test.cpp b/tests/network/http/websocket_test.cpp index cfa78d0..ffad3f4 100644 --- a/tests/network/http/websocket_test.cpp +++ b/tests/network/http/websocket_test.cpp @@ -59,6 +59,65 @@ BOOST_AUTO_TEST_CASE(websocket_test) BOOST_CHECK_THROW(c_conn->send_message( "again" ), fc::assert_exception); BOOST_CHECK_THROW(client.connect( "ws://localhost:" + fc::to_string(port) ), fc::exception); + l.remove_appender(ca); +} + +BOOST_AUTO_TEST_CASE(websocket_test_with_proxy_header) +{ + // set up logging + fc::shared_ptr ca(new fc::console_appender); + fc::logger l = fc::logger::get("rpc"); + l.add_appender( ca ); + + fc::http::websocket_client client; + // add the proxy header element + client.append_header("MyProxyHeaderKey", "MyServer:8080"); + + fc::http::websocket_connection_ptr s_conn, c_conn; + int port; + { + // the server will be on the lookout for the key in the header + fc::http::websocket_server server("MyProxyHeaderKey"); + server.on_connection([&]( const fc::http::websocket_connection_ptr& c ){ + s_conn = c; + c->on_message_handler([&](const std::string& s){ + c->send_message("echo: " + s); + }); + }); + + server.listen( 0 ); + port = server.get_listening_port(); + + server.start_accept(); + + std::string echo; + c_conn = client.connect( "ws://localhost:" + fc::to_string(port) ); + c_conn->on_message_handler([&](const std::string& s){ + echo = s; + }); + c_conn->send_message( "hello world" ); + fc::usleep( fc::milliseconds(100) ); + BOOST_CHECK_EQUAL("echo: hello world", echo); + c_conn->send_message( "again" ); + fc::usleep( fc::milliseconds(100) ); + BOOST_CHECK_EQUAL("echo: again", echo); + + s_conn->close(0, "test"); + fc::usleep( fc::milliseconds(100) ); + BOOST_CHECK_THROW(c_conn->send_message( "again" ), fc::exception); + + c_conn = client.connect( "ws://localhost:" + fc::to_string(port) ); + c_conn->on_message_handler([&](const std::string& s){ + echo = s; + }); + c_conn->send_message( "hello world" ); + fc::usleep( fc::milliseconds(100) ); + BOOST_CHECK_EQUAL("echo: hello world", echo); + } + + BOOST_CHECK_THROW(c_conn->send_message( "again" ), fc::assert_exception); + BOOST_CHECK_THROW(client.connect( "ws://localhost:" + fc::to_string(port) ), fc::exception); + l.remove_appender(ca); } BOOST_AUTO_TEST_SUITE_END() diff --git a/tests/ws_test_server.cpp b/tests/ws_test_server.cpp index 14700a9..e088681 100644 --- a/tests/ws_test_server.cpp +++ b/tests/ws_test_server.cpp @@ -13,7 +13,7 @@ int main(int argc, char** argv) fc::logger l = fc::logger::get("rpc"); l.add_appender( ca ); - fc::http::websocket_server server; + fc::http::websocket_server server("MyForwardHeaderKey"); server.on_connection([&]( const fc::http::websocket_connection_ptr& c ){ c->on_message_handler([&](const std::string& s){