refactor: implement socket structure and server refactor

This commit is contained in:
Quinten 2025-10-14 18:38:57 +02:00
parent 970ab847fc
commit f3bdf28eed
17 changed files with 324 additions and 310 deletions

View File

@ -1,4 +1,4 @@
#include <webserv/socket/Socket.hpp> #include <webserv/socket/ServerSocket.hpp>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -23,7 +23,7 @@ class SocketTest : public ::testing::Test
TEST_F(SocketTest, DefaultConstructor) TEST_F(SocketTest, DefaultConstructor)
{ {
Socket socket; ServerSocket socket;
// Socket should be created successfully // Socket should be created successfully
// We can't test much without actually creating network resources // We can't test much without actually creating network resources
SUCCEED(); SUCCEED();
@ -32,7 +32,7 @@ TEST_F(SocketTest, DefaultConstructor)
TEST_F(SocketTest, ConstructorWithFd) TEST_F(SocketTest, ConstructorWithFd)
{ {
// Socket constructor with invalid fd throws an exception // Socket constructor with invalid fd throws an exception
EXPECT_THROW(Socket socket(-1), std::runtime_error); EXPECT_THROW(ServerSocket socket(-1), std::runtime_error);
} }
TEST_F(SocketTest, MoveConstructor) TEST_F(SocketTest, MoveConstructor)

View File

@ -1,9 +1,9 @@
#include <webserv/client/Client.hpp> #include <webserv/client/Client.hpp>
#include <webserv/http/HttpHeaders.hpp> // for HttpHeaders #include <webserv/http/HttpHeaders.hpp> // for HttpHeaders
#include <webserv/log/Log.hpp> // for Log, LOCATION #include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/router/Router.hpp> // for Router #include <webserv/router/Router.hpp> // for Router
#include <webserv/server/Server.hpp> // for Server #include <webserv/server/Server.hpp> // for Server
#include <webserv/socket/Socket.hpp> // for Socket #include <webserv/socket/ClientSocket.hpp> // for Socket
#include <cstdint> // for uint8_t #include <cstdint> // for uint8_t
#include <functional> // for ref, reference_wrapper #include <functional> // for ref, reference_wrapper
@ -13,7 +13,7 @@
#include <sys/types.h> // for ssize_t #include <sys/types.h> // for ssize_t
Client::Client(std::unique_ptr<Socket> socket, Server &server) Client::Client(std::unique_ptr<ClientSocket> socket, Server &server)
: httpRequest_(std::make_unique<HttpRequest>(this)), httpResponse_(std::make_unique<HttpResponse>()), : httpRequest_(std::make_unique<HttpRequest>(this)), httpResponse_(std::make_unique<HttpResponse>()),
client_socket_(std::move(socket)), server_(std::ref(server)) client_socket_(std::move(socket)), server_(std::ref(server))
{ {
@ -44,7 +44,7 @@ void Client::request()
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
char buffer[bufferSize_] = {}; // NOLINT(cppcoreguidelines-avoid-c-arrays) char buffer[bufferSize_] = {}; // NOLINT(cppcoreguidelines-avoid-c-arrays)
ssize_t bytesRead = client_socket_->recv( ssize_t bytesRead = client_socket_->read(
buffer, sizeof(buffer) - 1); // NOLINT(cppcoreguidelines-pro-bounds-array-to-pointer-decay) buffer, sizeof(buffer) - 1); // NOLINT(cppcoreguidelines-pro-bounds-array-to-pointer-decay)
if (bytesRead < 0) if (bytesRead < 0)
{ {

View File

@ -2,12 +2,14 @@
// #include <webserv/http/HttpResponse.hpp> // #include <webserv/http/HttpResponse.hpp>
#include "webserv/socket/ClientSocket.hpp"
#include <webserv/config/ServerConfig.hpp> // for ServerConfig #include <webserv/config/ServerConfig.hpp> // for ServerConfig
#include <webserv/http/HttpConstants.hpp> // for OK #include <webserv/http/HttpConstants.hpp> // for OK
#include <webserv/http/HttpRequest.hpp> // for HttpRequest #include <webserv/http/HttpRequest.hpp> // for HttpRequest
#include <webserv/http/HttpResponse.hpp> // for HttpResponse #include <webserv/http/HttpResponse.hpp> // for HttpResponse
#include <webserv/server/Server.hpp> #include <webserv/server/Server.hpp>
#include <webserv/socket/Socket.hpp> // for Socket #include <webserv/socket/ClientSocket.hpp> // for Socket
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for uint8_t #include <cstdint> // for uint8_t
@ -15,14 +17,14 @@
#include <vector> // for vector #include <vector> // for vector
class Server; class Server;
class Socket; class ClientSocket;
class ServerConfig; class ServerConfig;
class HttpResponse; class HttpResponse;
class Client class Client
{ {
public: public:
Client(std::unique_ptr<Socket> socket, Server &server); Client(std::unique_ptr<ClientSocket> socket, Server &server);
Client(const Client &other) = delete; // Disable copy constructor Client(const Client &other) = delete; // Disable copy constructor
Client &operator=(const Client &other) = delete; // Disable copy assignment Client &operator=(const Client &other) = delete; // Disable copy assignment
@ -37,7 +39,7 @@ class Client
[[nodiscard]] bool isResponseReady() const; [[nodiscard]] bool isResponseReady() const;
[[nodiscard]] int getStatusCode() const; [[nodiscard]] int getStatusCode() const;
[[nodiscard]] Socket &getSocket() const { return *client_socket_; } [[nodiscard]] ClientSocket &getSocket() const { return *client_socket_; }
// void setError(int statusCode); // void setError(int statusCode);
@ -48,7 +50,7 @@ class Client
constexpr static size_t bufferSize_ = 4096; constexpr static size_t bufferSize_ = 4096;
std::unique_ptr<HttpRequest> httpRequest_ = nullptr; std::unique_ptr<HttpRequest> httpRequest_ = nullptr;
std::unique_ptr<HttpResponse> httpResponse_ = nullptr; std::unique_ptr<HttpResponse> httpResponse_ = nullptr;
std::unique_ptr<Socket> client_socket_; std::unique_ptr<ClientSocket> client_socket_;
Server &server_; Server &server_;
// mutable const ServerConfig *server_config_ = nullptr; // mutable const ServerConfig *server_config_ = nullptr;
}; };

View File

@ -49,7 +49,7 @@ class Log
void log(Level level, const std::string &message, const std::map<std::string, std::string> &context); void log(Level level, const std::string &message, const std::map<std::string, std::string> &context);
static constexpr Log::Level COMPILE_TIME_LOG_LEVEL = Log::Level::Info; static constexpr Log::Level COMPILE_TIME_LOG_LEVEL = Log::Level::Trace;
static void setFileChannel(const std::string &filename, std::ios_base::openmode mode = std::ios_base::app); static void setFileChannel(const std::string &filename, std::ios_base::openmode mode = std::ios_base::app);
static void setStdoutChannel(); static void setStdoutChannel();

View File

@ -6,7 +6,6 @@
#include <iostream> // for ios_base #include <iostream> // for ios_base
#include <string> // for allocator, basic_string, char_traits, operator+, string #include <string> // for allocator, basic_string, char_traits, operator+, string
#include <vector> // for vector
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
@ -37,6 +36,6 @@ int main(int argc, char **argv)
Log::debug("ConfigManager initialized successfully."); Log::debug("ConfigManager initialized successfully.");
Server server(configManager); Server server(configManager);
server.start(); server.run();
return 0; return 0;
} }

View File

@ -1,9 +1,12 @@
#include "webserv/socket/ASocket.hpp"
#include <webserv/client/Client.hpp> // for Client #include <webserv/client/Client.hpp> // for Client
#include <webserv/config/ConfigManager.hpp> // for ConfigManager #include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/config/ServerConfig.hpp> // for ServerConfig #include <webserv/config/ServerConfig.hpp> // for ServerConfig
#include <webserv/log/Log.hpp> // for Log, LOCATION #include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/server/Server.hpp> #include <webserv/server/Server.hpp>
#include <webserv/socket/Socket.hpp> // for Socket #include <webserv/socket/ClientSocket.hpp> // for ClientSocket
#include <webserv/socket/ServerSocket.hpp> // for ServerSocket
#include <cerrno> // for errno #include <cerrno> // for errno
#include <cstring> // for strerror #include <cstring> // for strerror
@ -38,6 +41,15 @@ Server::Server(const ConfigManager &configManager)
Log::fatal("epoll_create1 failed"); Log::fatal("epoll_create1 failed");
throw std::runtime_error("epoll_create1 failed"); throw std::runtime_error("epoll_create1 failed");
} }
for (const auto &config : configManager_.getServerConfigs())
{
setupServerSocket(*config);
}
if (listener_fds_.empty())
{
Log::fatal("No server sockets created.");
throw std::runtime_error("No server sockets created.");
}
} }
Server::~Server() Server::~Server()
@ -49,27 +61,13 @@ Server::~Server()
} }
} }
void Server::start() void Server::add(const ASocket &socket, uint32_t events, Client *client)
{ {
Log::trace(LOCATION); if (socket.getType() != ASocket::Type::SERVER_SOCKET && client == nullptr)
Log::info("Starting servers...");
// 1. Load server configurations
for (const auto &config : configManager_.getServerConfigs())
{ {
setupServerSocket(*config); Log::error("Client pointer must be provided for non-server sockets");
throw std::invalid_argument("Client pointer must be provided for non-server sockets");
} }
if (listener_fds_.empty())
{
Log::fatal("No server sockets created.");
throw std::runtime_error("No server sockets created.");
}
eventLoop();
}
void Server::add(const Socket &socket, uint32_t events) const
{
Log::trace(LOCATION); Log::trace(LOCATION);
int fd = socket.getFd(); int fd = socket.getFd();
struct epoll_event event{}; struct epoll_event event{};
@ -80,24 +78,27 @@ void Server::add(const Socket &socket, uint32_t events) const
Log::error("epoll_ctl ADD failed for fd: " + std::to_string(fd)); Log::error("epoll_ctl ADD failed for fd: " + std::to_string(fd));
throw std::runtime_error("epoll_ctl ADD failed"); throw std::runtime_error("epoll_ctl ADD failed");
} }
socketToClient_[fd] = client;
}
void Server::remove(const ASocket &socket)
{
Log::trace(LOCATION);
int fd = socket.getFd();
if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == -1)
{
Log::error("epoll_ctl DEL failed for fd: " + std::to_string(fd));
throw std::runtime_error("epoll_ctl DEL failed");
}
socketToClient_.erase(fd);
} }
void Server::disconnect(const Client &client) void Server::disconnect(const Client &client)
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
int client_fd = client.getSocket().getFd(); int client_fd = client.getSocket().getFd();
clients_.erase(client_fd);
}
void Server::remove(const Socket &socket) const std::erase_if(clients_, [&](const std::unique_ptr<Client> &c) { return c->getSocket().getFd() == client_fd; });
{
Log::trace(LOCATION);
int filedes = socket.getFd();
if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, filedes, nullptr) == -1)
{
Log::error("epoll_ctl DEL failed for fd: " + std::to_string(filedes));
throw std::runtime_error("epoll_ctl DEL failed");
}
} }
void Server::setupServerSocket(const ServerConfig &config) void Server::setupServerSocket(const ServerConfig &config)
@ -106,9 +107,8 @@ void Server::setupServerSocket(const ServerConfig &config)
try try
{ {
auto host = config.get<std::string>("host").value_or(std::string()); // TODO should not be a default host auto host = config.get<std::string>("host").value_or(std::string()); // TODO should not be a default host
auto port = config.get<int>("listen").value_or(0); // TODO should not be a default port
auto port = config.get<int>("listen").value_or(0); // TODO should not be a default port std::unique_ptr<ServerSocket> serverSocket = std::make_unique<ServerSocket>();
std::unique_ptr<Socket> serverSocket = std::make_unique<Socket>();
serverSocket->bind(host, port); serverSocket->bind(host, port);
serverSocket->listen(SOMAXCONN); serverSocket->listen(SOMAXCONN);
int server_fd = serverSocket->getFd(); int server_fd = serverSocket->getFd();
@ -118,7 +118,6 @@ void Server::setupServerSocket(const ServerConfig &config)
listeners_.push_back(std::move(serverSocket)); listeners_.push_back(std::move(serverSocket));
listener_fds_.insert(server_fd); listener_fds_.insert(server_fd);
Log::info("Server listening on " + host + ":" + std::to_string(port) + "..."); Log::info("Server listening on " + host + ":" + std::to_string(port) + "...");
// static_cast<std::string>(config["listen"]) + "...");
} }
catch (const std::exception &e) catch (const std::exception &e)
{ {
@ -129,13 +128,15 @@ void Server::setupServerSocket(const ServerConfig &config)
void Server::handleConnection(struct epoll_event *event) void Server::handleConnection(struct epoll_event *event)
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
Socket &listener = getListener(event->data.fd); ServerSocket &listener = getListener(event->data.fd);
std::unique_ptr<Socket> clientSocket = listener.accept(); std::unique_ptr<ClientSocket> clientSocket = listener.accept();
add(*clientSocket, EPOLLIN);
clients_.insert({clientSocket->getFd(), std::make_unique<Client>(std::move(clientSocket), *this)}); auto client = std::make_unique<Client>(std::move(clientSocket), *this);
add(client->getSocket(), EPOLLIN, client.get());
clients_.emplace_back(std::move(client));
} }
Socket &Server::getListener(int fd) const ServerSocket &Server::getListener(int fd) const
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
for (const auto &listener : listeners_) for (const auto &listener : listeners_)
@ -152,10 +153,9 @@ Socket &Server::getListener(int fd) const
Client &Server::getClient(int fd) const Client &Server::getClient(int fd) const
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
auto it = clients_.find(fd); if (socketToClient_.contains(fd))
if (it != clients_.end())
{ {
return *(it->second); return *(socketToClient_.at(fd));
} }
Log::error("Client not found for fd: " + std::to_string(fd)); Log::error("Client not found for fd: " + std::to_string(fd));
throw std::runtime_error("Client not found for fd: " + std::to_string(fd)); throw std::runtime_error("Client not found for fd: " + std::to_string(fd));
@ -198,7 +198,7 @@ void Server::handleResponse(struct epoll_event *event)
{ {
Log::debug("Sent " + std::to_string(bytesSent) + " bytes to fd: " + std::to_string(event->data.fd)); Log::debug("Sent " + std::to_string(bytesSent) + " bytes to fd: " + std::to_string(event->data.fd));
} }
clients_.erase(event->data.fd); disconnect(client);
} }
void Server::handleEvent(struct epoll_event *event) void Server::handleEvent(struct epoll_event *event)
@ -224,7 +224,7 @@ void Server::handleEvent(struct epoll_event *event)
} }
} }
void Server::eventLoop() void Server::run()
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
Log::info("Listening..."); Log::info("Listening...");

View File

@ -1,10 +1,12 @@
#pragma once #pragma once
#include "webserv/socket/ASocket.hpp"
#include <webserv/client/Client.hpp> #include <webserv/client/Client.hpp>
#include <webserv/config/ConfigManager.hpp> #include <webserv/config/ConfigManager.hpp>
#include <webserv/config/ServerConfig.hpp> // for ServerConfig #include <webserv/config/ServerConfig.hpp> // for ServerConfig
#include <webserv/router/Router.hpp> // for Router #include <webserv/router/Router.hpp> // for Router
#include <webserv/socket/Socket.hpp> // for Socket #include <webserv/socket/ServerSocket.hpp> // for ServerSocket
#include <cstdint> // for uint32_t #include <cstdint> // for uint32_t
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
@ -32,13 +34,13 @@ class Server
~Server(); ~Server();
void start(); void run();
void add(const Socket &socket, uint32_t events) const; void add(const ASocket &socket, uint32_t events, Client *client = nullptr);
void remove(const Socket &socket) const; void remove(const ASocket &socket);
void disconnect(const Client &client); void disconnect(const Client &client);
void responseReady(int client_fd) const; void responseReady(int client_fd) const;
Socket &getListener(int fd) const; ServerSocket &getListener(int fd) const;
Client &getClient(int fd) const; Client &getClient(int fd) const;
const Router &getRouter() const; const Router &getRouter() const;
@ -46,9 +48,11 @@ class Server
int epoll_fd_; int epoll_fd_;
const ConfigManager &configManager_; const ConfigManager &configManager_;
const Router router_; const Router router_;
std::vector<std::unique_ptr<Socket>> listeners_; std::vector<std::unique_ptr<ServerSocket>> listeners_;
std::set<int> listener_fds_; std::set<int> listener_fds_;
std::unordered_map<int, std::unique_ptr<Client>> clients_; // std::unordered_map<int, std::unique_ptr<Client>> clients_;
std::vector<std::unique_ptr<Client>> clients_;
std::unordered_map<int, Client *> socketToClient_;
void handleEvent(struct epoll_event *event); void handleEvent(struct epoll_event *event);
void handleConnection(struct epoll_event *event); void handleConnection(struct epoll_event *event);
@ -56,5 +60,4 @@ class Server
void handleResponse(struct epoll_event *event); void handleResponse(struct epoll_event *event);
void setupServerSocket(const ServerConfig &config); void setupServerSocket(const ServerConfig &config);
void eventLoop(); };
};

View File

@ -0,0 +1,66 @@
#include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/socket/ASocket.hpp>
#include <fcntl.h> // For fcntl()
#include <sys/socket.h>
#include <unistd.h> // For close()
ASocket::ASocket(int fd) : fd_(fd)
{
Log::trace(LOCATION);
if (fd_ == -1)
{
Log::error("Invalid file descriptor");
throw std::runtime_error("Invalid file descriptor");
}
setNonBlocking();
}
ASocket::~ASocket()
{
Log::trace(LOCATION);
if (fd_ != -1)
{
close(fd_);
}
}
ssize_t ASocket::read(void *buf, size_t len) const
{
ssize_t bytesRead = ::recv(fd_, buf, len, 0);
if (bytesRead == -1)
{
throw std::system_error(errno, std::generic_category(), "Socket: Read error");
}
return bytesRead;
}
ssize_t ASocket::write(const void *buf, size_t len) const
{
Log::trace(LOCATION);
ssize_t bytesSent = ::send(fd_, buf, len, 0);
if (bytesSent == -1)
{
throw std::system_error(errno, std::generic_category(), "Socket: Write error");
}
return bytesSent;
}
void ASocket::setNonBlocking() const
{
if (fcntl(fd_, F_SETFL, O_NONBLOCK) == -1)
{
throw std::system_error(errno, std::generic_category(), "ASocket: Failed to set FD non-blocking");
}
}
int ASocket::getFd() const
{
Log::trace(LOCATION);
return fd_;
}
void ASocket::setFd(int fd)
{
fd_ = fd;
}

View File

@ -0,0 +1,40 @@
#pragma once
#include <cstddef> // for size_t
#include <cstdint>
#include <sys/types.h> // for ssize_t
class ASocket
{
public:
enum class Type : uint8_t
{
CLIENT_SOCKET,
SERVER_SOCKET,
CGI_SOCKET
};
ASocket() = delete;
explicit ASocket(int fd);
ASocket(const ASocket &other) = delete;
ASocket &operator=(const ASocket &other) = delete;
ASocket(ASocket &&other) noexcept = default;
ASocket &operator=(ASocket &&other) noexcept = default;
virtual ~ASocket();
[[nodiscard]] virtual Type getType() const = 0;
[[nodiscard]] int getFd() const;
ssize_t read(void *buf, size_t len) const;
ssize_t write(const void *buf, size_t len) const;
protected:
void setNonBlocking() const;
void setFd(int fd);
private:
int fd_;
};

View File

@ -1,74 +1,12 @@
#include <webserv/log/Log.hpp>
#include <webserv/socket/CGISocket.hpp> #include <webserv/socket/CGISocket.hpp>
#include <cstring> // for strerror CGISocket::CGISocket(int fd) : ASocket(fd)
#include <stdexcept> // for runtime_error
#include <string> // for string, operator+
#include <system_error> // for system_error, error_code
#include <fcntl.h> // for fcntl, O_NONBLOCK
#include <unistd.h> // for close, read, write
CGISocket::CGISocket(int readFd, int writeFd) : _readFd(readFd), _writeFd(writeFd)
{ {
if (_readFd == -1 || _writeFd == -1) Log::trace(LOCATION);
{
throw std::runtime_error("CGISocket: Invalid file descriptors");
}
setNonBlocking();
} }
CGISocket::~CGISocket() ASocket::Type CGISocket::getType() const
{ {
if (_readFd != -1) return ASocket::Type::CGI_SOCKET;
{
close(_readFd);
}
if (_writeFd != -1)
{
close(_writeFd);
}
} }
int CGISocket::getReadFd() const
{
return _readFd;
}
int CGISocket::getWriteFd() const
{
return _writeFd;
}
ssize_t CGISocket::read(void *buffer, size_t size) const
{
ssize_t bytesRead = ::read(_readFd, buffer, size);
if (bytesRead == -1)
{
throw std::system_error(errno, std::generic_category(), "CGISocket: Read error");
}
return bytesRead;
}
ssize_t CGISocket::write(const void *buffer, size_t size) const
{
ssize_t bytesWritten = ::write(_writeFd, buffer, size);
if (bytesWritten == -1)
{
throw std::system_error(errno, std::generic_category(), "CGISocket: Write error");
}
return bytesWritten;
}
void CGISocket::setNonBlocking() const
{
if (fcntl(_readFd, F_SETFL, O_NONBLOCK) == -1)
{
throw std::system_error(errno, std::generic_category(), "CGISocket: Failed to set read FD non-blocking");
}
if (fcntl(_writeFd, F_SETFL, O_NONBLOCK) == -1)
{
throw std::system_error(errno, std::generic_category(), "CGISocket: Failed to set write FD non-blocking");
}
}

View File

@ -1,32 +1,15 @@
#pragma once #pragma once
#include <webserv/socket/ASocket.hpp>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
class CGISocket class CGISocket : public ASocket
{ {
public: public:
CGISocket(int readFd, int writeFd); [[nodiscard]] ASocket::Type getType() const override;
CGISocket(const CGISocket &) = delete;
CGISocket &operator=(const CGISocket &) = delete;
CGISocket(CGISocket &&other) noexcept = default;
CGISocket &operator=(CGISocket &&other) noexcept = default;
~CGISocket();
// Get FDs for server epoll registration
[[nodiscard]] int getReadFd() const; // Server reads CGI output
[[nodiscard]] int getWriteFd() const; // Server writes to CGI input
// Handler interface
ssize_t read(void *buffer, size_t size) const;
ssize_t write(const void *buffer, size_t size) const;
void setNonBlocking() const;
private:
int _readFd = -1; // Server side of output pipe
int _writeFd = -1; // Server side of input pipe
explicit CGISocket(int fd);
}; };

View File

@ -0,0 +1,13 @@
#include <webserv/log/Log.hpp>
#include <webserv/socket/ClientSocket.hpp>
ClientSocket::ClientSocket(int fd) : ASocket(fd)
{
Log::trace(LOCATION);
}
ASocket::Type ClientSocket::getType() const
{
Log::trace(LOCATION);
return ASocket::Type::CLIENT_SOCKET;
}

View File

@ -0,0 +1,15 @@
#pragma once
#include <webserv/socket/ASocket.hpp>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
class ClientSocket : public ASocket
{
public:
[[nodiscard]] ASocket::Type getType() const override;
explicit ClientSocket(int fd);
};

View File

@ -0,0 +1,83 @@
#include "webserv/socket/ASocket.hpp"
#include "webserv/socket/ClientSocket.hpp"
#include <webserv/log/Log.hpp>
#include <webserv/socket/ServerSocket.hpp>
#include <memory>
#include <stdexcept>
#include <arpa/inet.h> // For inet_addr
#include <fcntl.h> // For fcntl()"
#include <netinet/in.h> // For sockaddr_in
#include <sys/socket.h>
#include <unistd.h> // For close()
ServerSocket::ServerSocket() : ASocket(socket(AF_INET, SOCK_STREAM, 0))
{
Log::trace(LOCATION);
if (getFd() == -1)
{
Log::error("Socket creation failed");
throw std::runtime_error("Socket creation failed");
}
int opt = 1;
if (setsockopt(getFd(), SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0)
{
close(getFd());
setFd(-1);
Log::error("setsockopt failed");
throw std::runtime_error("setsockopt failed");
}
setNonBlocking();
}
ServerSocket::ServerSocket(int fd) : ASocket(fd)
{
Log::trace(LOCATION);
}
void ServerSocket::listen(int backlog) const
{
Log::trace(LOCATION);
if (::listen(getFd(), backlog) < 0)
{
Log::error("Listen failed");
throw std::runtime_error("Listen failed");
}
}
void ServerSocket::bind(const std::string &host, const int port) const
{
Log::trace(LOCATION);
struct sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_addr.s_addr = inet_addr(host.c_str());
address.sin_port = htons(port);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
if (::bind(getFd(), reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0)
{
Log::fatal("Cannot bind to " + host + ":" + std::to_string(port)
+ " - address already in use or permission denied");
throw std::runtime_error("Bind failed");
}
}
ASocket::Type ServerSocket::getType() const
{
Log::trace(LOCATION);
return ASocket::Type::SERVER_SOCKET;
}
std::unique_ptr<ClientSocket> ServerSocket::accept() const
{
Log::trace(LOCATION);
int client_fd = ::accept(getFd(), nullptr, nullptr);
if (client_fd < 0)
{
Log::error("Accept failed");
throw std::runtime_error("Accept failed");
}
return std::make_unique<ClientSocket>(client_fd);
}

View File

@ -0,0 +1,22 @@
#pragma once
#include "webserv/socket/ClientSocket.hpp"
#include <webserv/socket/ASocket.hpp>
#include <memory> // for unique_ptr
#include <string> // for string
class ServerSocket : public ASocket
{
public:
ServerSocket();
ServerSocket(int fd);
void listen(int backlog) const;
void bind(const std::string &host, int port) const;
[[nodiscard]] ASocket::Type getType() const override;
[[nodiscard]] std::unique_ptr<ClientSocket> accept() const;
};

View File

@ -1,118 +0,0 @@
#include <webserv/socket/Socket.hpp>
#include <webserv/log/Log.hpp>
#include <memory>
#include <stdexcept>
#include <arpa/inet.h> // For inet_addr
#include <fcntl.h> // For fcntl()"
#include <netinet/in.h> // For sockaddr_in
#include <sys/socket.h>
#include <unistd.h> // For close()
Socket::Socket() : fd_(socket(AF_INET, SOCK_STREAM, 0))
{
Log::trace(LOCATION);
if (fd_ == -1)
{
Log::error("Socket creation failed");
throw std::runtime_error("Socket creation failed");
}
int opt = 1;
if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0)
{
close(fd_);
fd_ = -1;
Log::error("setsockopt failed");
throw std::runtime_error("setsockopt failed");
}
setNonBlocking();
}
Socket::Socket(int fd) : fd_(fd) // NOLINT(readability-identifier-naming)
{
Log::trace(LOCATION);
if (fd_ == -1)
{
Log::error("Invalid file descriptor");
throw std::runtime_error("Invalid file descriptor");
}
setNonBlocking();
}
Socket::~Socket()
{
Log::trace(LOCATION);
if (fd_ != -1)
{
close(fd_);
}
}
void Socket::listen(int backlog) const
{
Log::trace(LOCATION);
if (::listen(fd_, backlog) < 0)
{
Log::error("Listen failed");
throw std::runtime_error("Listen failed");
}
}
void Socket::bind(const std::string &host, const int port) const
{
Log::trace(LOCATION);
struct sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_addr.s_addr = inet_addr(host.c_str());
address.sin_port = htons(port);
if (::bind(fd_, reinterpret_cast<struct sockaddr *>(&address), sizeof(address))
< 0) // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
{
Log::fatal("Cannot bind to " + host + ":" + std::to_string(port)
+ " - address already in use or permission denied");
throw std::runtime_error("Bind failed");
}
}
std::unique_ptr<Socket> Socket::accept() const
{
Log::trace(LOCATION);
int client_fd = ::accept(fd_, nullptr, nullptr);
if (client_fd < 0)
{
Log::error("Accept failed");
throw std::runtime_error("Accept failed");
}
return std::make_unique<Socket>(client_fd);
}
ssize_t Socket::recv(void *buf, size_t len) const
{
Log::trace(LOCATION);
return ::recv(fd_, buf, len, 0);
}
ssize_t Socket::send(const void *buf, size_t len) const
{
Log::trace(LOCATION);
return ::send(fd_, buf, len, 0);
}
void Socket::setNonBlocking() const
{
Log::trace(LOCATION);
if (fcntl(fd_, F_SETFL, O_NONBLOCK) < 0)
{
Log::error("Failed to set non-blocking mode");
throw std::runtime_error("Failed to set non-blocking mode");
}
}
int Socket::getFd() const
{
Log::trace(LOCATION);
return fd_;
}

View File

@ -1,32 +0,0 @@
#pragma once
#include <cstddef> // for size_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <sys/types.h> // for ssize_t
class Socket
{
public:
Socket();
Socket(int fd); // NOLINT readability-identifier-naming
Socket(const Socket &other) = delete; // Disable copy constructor
Socket &operator=(const Socket &other) = delete; // Disable copy assignment
Socket(Socket &&other) noexcept = default; // Move constructor
Socket &operator=(Socket &&other) noexcept = default; // Move assignment
~Socket();
void listen(int backlog) const;
void bind(const std::string &host, int port) const;
[[nodiscard]] std::unique_ptr<Socket> accept() const;
ssize_t recv(void *buf, size_t len) const;
ssize_t send(const void *buf, size_t len) const;
void setNonBlocking() const;
[[nodiscard]] int getFd() const;
private:
int fd_;
};