Skip to content

Commit

Permalink
Merge pull request #5 from kriegalex/develop
Browse files Browse the repository at this point in the history
v1.0.0-beta2
  • Loading branch information
kriegalex authored May 2, 2024
2 parents 5ebf81a + 9bda366 commit f6effb7
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 53 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ add_executable(${PROJECT_NAME} main.cpp
MemoryBoundedQueue.cpp)
target_link_libraries(SecureSyslogServer OpenSSL::SSL OpenSSL::Crypto)
if (WIN32)
target_link_libraries(${PROJECT_NAME} ws2_32)
target_link_libraries(${PROJECT_NAME} ws2_32 ntdll)
endif ()
target_include_directories(SecureSyslogServer PRIVATE ${OPENSSL_INCLUDE_DIR})

Expand Down
1 change: 0 additions & 1 deletion Logger.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "Logger.h"

#include <filesystem>
#include <sstream>

#include "ScreenLogger.h"
Expand Down
43 changes: 34 additions & 9 deletions SSLUtil.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "SSLUtil.h"

#include <iostream>
#include <stdexcept>
#include <openssl/err.h>
#ifdef OPENSSL_SYS_WINDOWS
#include <winsock.h>
#include <winsock2.h>
#include <ip2string.h>
#endif

SSL_CTX *SSLUtil::createServerContext() {
Expand Down Expand Up @@ -45,7 +48,7 @@ int SSLUtil::createSocket(int port) {
throw std::runtime_error("Unable to create socket.");
}

sockaddr_in addr;
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = INADDR_ANY;
Expand All @@ -55,10 +58,6 @@ int SSLUtil::createSocket(int port) {
if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (char *) &optval, sizeof(optval)) < 0) {
throw std::runtime_error("Unable to set socket option SO_REUSEADDR.");
}
// Set TCP_NODELAY for syslog performance tuning
if (setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char *) &optval, sizeof(optval)) < 0) {
throw std::runtime_error("Unable to set socket option TCP_NODELAY.");
}

if (bind(s, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
throw std::runtime_error("Unable to bind to socket.");
Expand All @@ -71,12 +70,26 @@ int SSLUtil::createSocket(int port) {
return s;
}

std::string SSLUtil::getClientIP(int clientSocket) {
struct sockaddr_in client_addr{};
int addr_len = sizeof(client_addr);

// Retrieve client information
if (getpeername(clientSocket, (struct sockaddr *) &client_addr, &addr_len) == 0) {
char client_ip[16];
RtlIpv4AddressToStringA(&client_addr.sin_addr, client_ip);
return std::move(std::string(client_ip));
}
return std::move(std::string("Unknown"));
}

int SSLUtil::acceptClient(int serverSocket) {
sockaddr_in addr;
sockaddr_in addr{};
int len = sizeof(addr);
int client = accept(serverSocket, (struct sockaddr *) &addr, &len);
if (client < 0) {
throw std::runtime_error("Unable to accept client.");
if (client == INVALID_SOCKET) {
if (WSAGetLastError() != WSAEINTR) // the server was probably killed intentionally
throw std::runtime_error("Unable to accept client.");
}
return client;
}
Expand All @@ -90,3 +103,15 @@ SSL *SSLUtil::createSSL(SSL_CTX *ctx, int clientSocket) {
throw std::runtime_error("Unable to set the SSL file descriptor.");
return ssl;
}

void SSLUtil::setupClient(int clientSocket) {
int timeout = 60000; // Timeout in milliseconds, 60 sec
if (setsockopt(clientSocket, SOL_SOCKET, SO_RCVTIMEO, (const char *) &timeout, sizeof(timeout)) < 0) {
throw std::runtime_error("Unable to set socket option SO_RCVTIMEO.");
}
int optval = 1;
// Set TCP_NODELAY for syslog performance tuning
if (setsockopt(clientSocket, IPPROTO_TCP, TCP_NODELAY, (char *) &optval, sizeof(optval)) < 0) {
throw std::runtime_error("Unable to set socket option TCP_NODELAY.");
}
}
4 changes: 4 additions & 0 deletions SSLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#define SSLUTIL_H

#include <openssl/ssl.h>
#include <string>

class SSLUtil {
public:
static SSL_CTX *createServerContext();
static int createSocket(int port);
static int acceptClient(int serverSocket);
static void setupClient(int clientSocket);
static std::string getClientIP(int clientSocket);
static std::string sslErrorToString(int error);
static SSL *createSSL(SSL_CTX *ctx, int clientSocket);
static void initWinSocket();
static void cleanWinSocket();
Expand Down
98 changes: 71 additions & 27 deletions SyslogServer.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "SyslogServer.h"

#include <csignal>
#include <utility>
#include <openssl/err.h>

// Initialize static instance pointer
SyslogServer *SyslogServer::instance_ = nullptr;

SyslogServer::SyslogServer(const std::string &configPath) : config_(configPath),
logger_(config_),
client_socket_(-1),
ssl_(nullptr) {
SyslogServer::SyslogServer(const std::string &configPath)
: config_(configPath), logger_ptr_(std::make_shared<Logger>(config_)) {
instance_ = this;
ssl_ctx_ = SSLUtil::createServerContext();
SSLUtil::initWinSocket();
Expand All @@ -22,10 +23,10 @@ SyslogServer::~SyslogServer() {
}

void SyslogServer::shutdownServer(int sig) {
if(sig != SIGINT) // unexpected
if (sig != SIGINT) // unexpected
return;
std::cout << "Shutdown signal received" << std::endl;
if (instance_) {
if (instance_->running_) {
instance_->running_ = false;
instance_->cleanup();
}
Expand Down Expand Up @@ -63,47 +64,90 @@ void SyslogServer::run() {
}

void SyslogServer::acceptConnections() {
while (running_) {
client_socket_ = SSLUtil::acceptClient(server_socket_);
ssl_ = SSLUtil::createSSL(ssl_ctx_, client_socket_);
if (SSL_accept(ssl_) == 1) {
handleClient();
while (instance_->running_) {
int client_socket = SSLUtil::acceptClient(server_socket_);
if(client_socket != INVALID_SOCKET) {
SSLUtil::setupClient(client_socket);

std::string client_ip = SSLUtil::getClientIP(client_socket);
std::cout << "Client connected: " << client_ip << std::endl;
SSL *ssl = SSLUtil::createSSL(ssl_ctx_, client_socket);
// use shared pointer
auto thread = std::make_shared<SyslogServerThread>(ssl, client_socket, client_ip, logger_ptr_);

{ // save as weak_ptr to signal later without increasing ownership count
std::lock_guard<std::mutex> lock(shutdown_mutex_);
threads_.emplace_back(thread);
}

// use lambda to create a new thread and add it to the vector
std::thread([thread]() {
thread->run();
}).detach();
}
clientCleanup();
}
}

void SyslogServer::handleClient() {
char buffer[20 * 1024];
while (int len = SSL_read(ssl_, buffer, static_cast<int>(sizeof(buffer) - 1))) {
if (len > 0) {
buffer[len] = '\0';
logger_.processMessage(buffer);
void SyslogServer::cleanup() {
std::lock_guard<std::mutex> lock(shutdown_mutex_);
for(const auto& weak_thread: threads_) {
if(auto thread = weak_thread.lock()) {
thread->clientCleanup();
}
}
serverCleanup();
}

void SyslogServer::cleanup() {
std::cout << "Cleaning up sockets..." << std::endl;
clientCleanup();
serverCleanup();
void SyslogServer::serverCleanup() {
if (server_socket_ != -1) {
std::cout << "Cleaning up server socket..." << std::endl;
closesocket(server_socket_);
server_socket_ = -1;
}
}

SyslogServerThread::SyslogServerThread(SSL *ssl,
int client_socket,
std::string client_ip,
std::shared_ptr<Logger> logger_ptr)
: ssl_(ssl), client_socket_(client_socket), client_ip_(std::move(client_ip)), logger_ptr_(std::move(logger_ptr)) {}

void SyslogServerThread::handleClient() {
char buffer[20 * 1024] = {0};
int rx_len;
while ((rx_len = SSL_read(ssl_, buffer, static_cast<int>(sizeof(buffer) - 1))) > 0) {
buffer[rx_len] = '\0';
logger_ptr_->processMessage(buffer);
}
if(rx_len != 0) { // 0 is clean disconnect
int ssl_err = SSL_get_error(ssl_,rx_len);
auto err_err = ERR_get_error();
if(ssl_err == SSL_ERROR_SYSCALL) {
if(err_err != 0) // 0 is most probably an unexpected timeout/disconnect
std::cerr << "Socket I/O error" << std::endl;
} else {
std::cerr << "SSL " << ERR_error_string(err_err, NULL) << std::endl;
}
}
}

void SyslogServer::clientCleanup() {
void SyslogServerThread::clientCleanup() {
if (ssl_) {
SSL_shutdown(ssl_);
SSL_free(ssl_);
ssl_ = nullptr;
}
if (client_socket_ != -1) {
std::cout << "Client disconnected: " << client_ip_ << std::endl;
closesocket(client_socket_);
client_socket_ = -1;
}
}

void SyslogServer::serverCleanup() {
if (server_socket_ != -1) {
closesocket(server_socket_);
server_socket_ = -1;
void SyslogServerThread::run() {
if (SSL_accept(ssl_) == 1) {
handleClient();
}
// Cleanup the client connection
clientCleanup();
}
32 changes: 23 additions & 9 deletions SyslogServer.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
#pragma once

#include "Config.h"
#include <string>
#include <iostream>
#include <openssl/ssl.h>

#include "Config.h"
#include "SSLUtil.h"
#include "Logger.h"

class SyslogServerThread {
public:
SyslogServerThread(SSL *ssl,
int client_socket,
std::string client_ip,
std::shared_ptr<Logger> logger_ptr);
void run();
void clientCleanup();

private:
SSL *ssl_;
int client_socket_;
std::string client_ip_;
std::shared_ptr<Logger> logger_ptr_;

void handleClient();
};

class SyslogServer {
public:
explicit SyslogServer(const std::string &configPath);
Expand All @@ -17,11 +33,11 @@ class SyslogServer {

private:
Config config_;
Logger logger_;
SSL_CTX *ssl_ctx_;
std::shared_ptr<Logger> logger_ptr_;
SSL_CTX *ssl_ctx_{};
int server_socket_;
int client_socket_;
SSL *ssl_;
std::vector<std::weak_ptr<SyslogServerThread>> threads_;
std::mutex shutdown_mutex_;
bool running_{};

static SyslogServer *instance_;
Expand All @@ -30,7 +46,5 @@ class SyslogServer {
static void setupSignals();
static void enableVirtualTerminalProcessing();
void acceptConnections();
void handleClient();
void clientCleanup();
void serverCleanup();
};
7 changes: 1 addition & 6 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
int main(int argc, char **argv) {
try {
SyslogServer server("config.json");
try {
server.run();
}
catch (const std::exception &e) {
server.cleanup();
}
server.run();
}
catch (const std::exception &e) {
std::cerr << "Error: " << e.what() << std::endl;
Expand Down

0 comments on commit f6effb7

Please sign in to comment.