refactor(Client, Server, ASocket): update socket handling and state management

This commit is contained in:
Quinten 2025-10-21 16:55:11 +02:00
parent eedb5382ea
commit cfb88745bb
13 changed files with 160 additions and 42 deletions

View File

@ -98,23 +98,33 @@ void Client::request()
}
}
//
void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
{
server_.add(*cgiStdIn, EPOLLOUT, this); // write
server_.add(*cgiStdOut, EPOLLIN, this); // read
// //
// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
// {
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
sockets_[cgiStdIn->getFd()] = cgiStdIn;
sockets_[cgiStdOut->getFd()] = cgiStdOut;
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
// }
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
// {
// server_.remove(*cgiSocket); // write
// sockets_.erase(cgiSocket->getFd());
// }
void Client::addSocket(ASocket *socket)
{
server_.add(*socket, this);
sockets_[socket->getFd()] = socket;
}
void Client::removeCgiSocket(CgiSocket *cgiSocket)
void Client::removeSocket(ASocket *socket)
{
server_.remove(*cgiSocket); // write
sockets_.erase(cgiSocket->getFd());
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
server_.remove(*socket);
sockets_.erase(socket->getFd());
}
void Client::poll() const
@ -130,7 +140,8 @@ void Client::poll() const
{
Log::info("Response is ready to be sent to client, fd: " + std::to_string(clientSocket_->getFd()));
clientSocket_->setCallback([this]() { respond(); });
server_.writable(clientSocket_->getFd());
// server_.writable(clientSocket_->getFd());
clientSocket_->setIOState(ASocket::IOState::WRITE);
}
}

View File

@ -43,8 +43,11 @@ class Client
[[nodiscard]] ASocket &getSocket(int fd = -1) const;
// void setStatusCode(int code);
void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut);
void removeCgiSocket(CgiSocket *cgiSocket);
// void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut);
// void removeCgiSocket(CgiSocket *cgiSocket);
void addSocket(ASocket *socket);
void removeSocket(ASocket *socket);
[[nodiscard]] HttpRequest &getHttpRequest() const noexcept;
[[nodiscard]] HttpResponse &getHttpResponse() const noexcept;

View File

@ -34,7 +34,7 @@ void CgiHandler::write()
if (request_.getBody().empty())
{
Log::debug("No body to write to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
request_.getClient().removeCgiSocket(cgiStdIn_.get());
request_.getClient().removeSocket(cgiStdIn_.get());
cgiStdIn_ = nullptr;
return;
}
@ -48,7 +48,7 @@ void CgiHandler::write()
Log::debug("Wrote " + std::to_string(bytesWritten)
+ " bytes to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
}
request_.getClient().removeCgiSocket(cgiStdIn_.get());
request_.getClient().removeSocket(cgiStdIn_.get());
cgiStdIn_ = nullptr;
}
@ -70,7 +70,7 @@ void CgiHandler::read()
else if (bytesRead == 0)
{
Log::info("CGI process closed stdout, fd: " + std::to_string(cgiStdOut_->getFd()));
request_.getClient().removeCgiSocket(cgiStdOut_.get());
request_.getClient().removeSocket(cgiStdOut_.get());
cgiStdOut_ = nullptr;
parseCgiOutput();
return;
@ -92,9 +92,8 @@ void CgiHandler::setCgiSockets(std::unique_ptr<CgiSocket> cgiStdIn, std::unique_
cgiStdOut_ = std::move(cgiStdOut);
cgiStdIn_ = std::move(cgiStdIn);
request_.getClient().setCgiSockets(cgiStdIn_.get(), cgiStdOut_.get()); // write
// TODO add to handler
request_.getClient().addSocket(cgiStdIn_.get());
request_.getClient().addSocket(cgiStdOut_.get());
}
void CgiHandler::wait() noexcept

View File

@ -74,8 +74,8 @@ void CgiProcess::spawn()
else
{
// Parent process
auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1]);
auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0]);
auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1], ASocket::IOState::WRITE);
auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0], ASocket::IOState::READ);
close(pipeStdin[0]);
close(pipeStdout[1]);

View File

@ -1,3 +1,5 @@
#include "webserv/utils/utils.hpp"
#include <webserv/client/Client.hpp> // for Client
#include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/config/ServerConfig.hpp> // for ServerConfig
@ -59,7 +61,7 @@ Server::~Server()
}
}
void Server::add(const ASocket &socket, uint32_t events, Client *client)
void Server::add(ASocket &socket, Client *client)
{
if (socket.getType() != ASocket::Type::SERVER_SOCKET && client == nullptr)
{
@ -69,7 +71,7 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
Log::trace(LOCATION);
int fd = socket.getFd();
struct epoll_event event{};
event.events = events;
event.events = utils::stateToEpoll(socket.getEvent());
event.data.fd = fd;
if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == -1)
{
@ -78,9 +80,10 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
}
Log::debug("Socket added to epoll, fd: " + std::to_string(fd));
socketToClient_[fd] = client;
sockets_.insert(&socket);
}
void Server::remove(const ASocket &socket)
void Server::remove(ASocket &socket)
{
Log::trace(LOCATION);
int fd = socket.getFd();
@ -90,6 +93,7 @@ void Server::remove(const ASocket &socket)
throw std::runtime_error("epoll_ctl DEL failed");
}
socketToClient_.erase(fd);
sockets_.erase(&socket);
}
void Server::disconnect(const Client &client)
@ -112,7 +116,7 @@ void Server::setupServerSocket(const ServerConfig &config)
serverSocket->listen(SOMAXCONN);
int server_fd = serverSocket->getFd();
add(*serverSocket, EPOLLIN);
add(*serverSocket);
listeners_.push_back(std::move(serverSocket));
listener_fds_.insert(server_fd);
@ -131,7 +135,7 @@ void Server::handleConnection(struct epoll_event *event)
std::unique_ptr<ClientSocket> clientSocket = listener.accept();
auto client = std::make_unique<Client>(std::move(clientSocket), *this);
add(client->getSocket(), EPOLLIN, client.get());
add(client->getSocket(), client.get());
clients_.emplace_back(std::move(client));
}
@ -182,6 +186,23 @@ void Server::writable(int client_fd) const
}
}
void Server::update(const ASocket &socket) const
{
Log::trace(LOCATION);
int socketFd = socket.getFd();
uint32_t events = utils::stateToEpoll(socket.getEvent());
Log::debug("Socket (" + std::to_string(socket.getFd()) + ") is being updated");
struct epoll_event evt{};
evt.events = events;
evt.data.fd = socketFd;
if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, socketFd, &evt) == -1)
{
Log::error("epoll_ctl MOD failed for fd: " + std::to_string(socketFd));
throw std::runtime_error("epoll_ctl MOD failed");
}
}
void Server::handleResponse(struct epoll_event *event) const
{
int socket_fd = event->data.fd;
@ -254,6 +275,18 @@ void Server::pollClients() const
}
}
void Server::pollSockets()
{
for (auto *socket : sockets_)
{
if (socket->isDirty())
{
update(*socket);
socket->processed();
}
}
}
void Server::run()
{
Log::trace(LOCATION);
@ -262,6 +295,7 @@ void Server::run()
struct epoll_event events[MAX_EVENTS]; // NOLINT
while (true)
{
pollSockets();
pollClients();
handleEpoll(events, MAX_EVENTS);
}

View File

@ -33,8 +33,9 @@ class Server
~Server();
void run();
void add(const ASocket &socket, uint32_t events, Client *client = nullptr);
void remove(const ASocket &socket);
void add(ASocket &socket, Client *client = nullptr);
void remove(ASocket &socket);
void update(const ASocket &socket) const;
void disconnect(const Client &client);
void writable(int client_fd) const;
@ -48,8 +49,10 @@ class Server
std::set<int> listener_fds_;
std::vector<std::unique_ptr<Client>> clients_;
std::unordered_map<int, Client *> socketToClient_;
std::set<ASocket *> sockets_;
void pollClients() const;
void pollSockets();
void handleEpoll(struct epoll_event *events, int max_events);
void handleEpollHangUp(struct epoll_event *event) const;

View File

@ -10,7 +10,7 @@
#include <sys/socket.h> // for recv, send
#include <unistd.h> // for close
ASocket::ASocket(int fd) : fd_(fd)
ASocket::ASocket(int fd, IOState event) : fd_(fd), ioState_(event)
{
Log::trace(LOCATION);
if (fd_ == -1)
@ -65,6 +65,34 @@ int ASocket::getFd() const noexcept
return fd_;
}
ASocket::IOState ASocket::getEvent() const noexcept
{
return ioState_;
}
bool ASocket::isDirty() const noexcept
{
return dirty_;
}
void ASocket::setIOState(IOState event)
{
if (event == ioState_)
{
return;
}
Log::debug("Processing state change for socket " + std::to_string(fd_));
dirty_ = true;
ioState_ = event;
}
void ASocket::processed()
{
Log::debug("Socket " + std::to_string(fd_) + " processed");
dirty_ = false;
}
void ASocket::setFd(int fd)
{
fd_ = fd;

View File

@ -16,8 +16,15 @@ class ASocket
CGI_SOCKET
};
enum class IOState : uint32_t
{
NONE = 0,
READ = 1 << 0,
WRITE = 1 << 1,
};
ASocket() = delete;
explicit ASocket(int fd);
explicit ASocket(int fd, IOState state = IOState::READ);
ASocket(const ASocket &other) = delete;
ASocket &operator=(const ASocket &other) = delete;
@ -28,6 +35,8 @@ class ASocket
[[nodiscard]] virtual Type getType() const noexcept = 0;
[[nodiscard]] int getFd() const noexcept;
[[nodiscard]] IOState getEvent() const noexcept;
[[nodiscard]] bool isDirty() const noexcept;
void callback() const;
void setCallback(std::function<void()> callback);
@ -35,11 +44,16 @@ class ASocket
virtual ssize_t read(void *buf, size_t len) const;
virtual ssize_t write(const void *buf, size_t len) const;
void setIOState(IOState event);
void processed();
protected:
void setNonBlocking() const;
void setFd(int fd);
private:
int fd_;
bool dirty_ = false;
IOState ioState_;
std::function<void()> callback_ = nullptr;
};

View File

@ -1,9 +1,10 @@
#include <webserv/log/Log.hpp> // for LOCATION, Log
#include <webserv/socket/ASocket.hpp> // for ASocket
#include <webserv/socket/CgiSocket.hpp>
#include <unistd.h>
CgiSocket::CgiSocket(int fd) : ASocket(fd)
CgiSocket::CgiSocket(int fd, ASocket::IOState event) : ASocket(fd, event)
{
Log::trace(LOCATION);
}
@ -13,7 +14,7 @@ ASocket::Type CgiSocket::getType() const noexcept
return ASocket::Type::CGI_SOCKET;
}
ssize_t CgiSocket::read(void *buf, size_t len) const
ssize_t CgiSocket::read(void *buf, size_t len) const
{
Log::trace(LOCATION);
ssize_t bytesRead = ::read(getFd(), buf, len);
@ -24,7 +25,7 @@ ssize_t CgiSocket::read(void *buf, size_t len) const
return bytesRead;
}
ssize_t CgiSocket::write(const void *buf, size_t len) const
ssize_t CgiSocket::write(const void *buf, size_t len) const
{
Log::trace(LOCATION);
ssize_t bytesSent = ::write(getFd(), buf, len);

View File

@ -9,10 +9,10 @@
class CgiSocket : public ASocket
{
public:
explicit CgiSocket(int fd);
explicit CgiSocket(int fd, ASocket::IOState event);
[[nodiscard]] ASocket::Type getType() const noexcept override;
ssize_t read(void *buf, size_t len) const override;
ssize_t write(const void *buf, size_t len) const override;
ssize_t read(void *buf, size_t len) const override;
ssize_t write(const void *buf, size_t len) const override;
};

View File

@ -51,7 +51,7 @@ void ServerSocket::bind(const std::string &host, const int port) const
struct sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_addr.s_addr = inet_addr(host.c_str());
address.sin_port = htons(port);
address.sin_port = htons(static_cast<uint16_t>(port));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
if (::bind(getFd(), reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0)

View File

@ -1,11 +1,17 @@
#include "webserv/socket/ASocket.hpp"
#include <webserv/utils/utils.hpp>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
#include <sys/epoll.h>
namespace utils
{
size_t stoul(const std::string &str, int base)
@ -132,4 +138,19 @@ std::string implode(const std::vector<std::string> &elements, const std::string
}
return stream.str();
}
uint32_t stateToEpoll(const ASocket::IOState &event)
{
uint32_t epollEvents = 0;
using EventType = std::underlying_type_t<ASocket::IOState>;
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::READ)) != 0U)
{
epollEvents |= EPOLLIN;
}
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::WRITE)) != 0U)
{
epollEvents |= EPOLLOUT;
}
return epollEvents;
}
} // namespace utils

View File

@ -1,13 +1,15 @@
#pragma once
#include "webserv/socket/ASocket.hpp"
#include <cstddef> // for size_t
#include <string> // for string
#include <cstdint>
#include <string> // for string
#include <vector>
namespace utils
{
size_t stoul(const std::string &str, int base = 10);
// std::string trimSemi(const std::string &str);
std::string trim(const std::string &str, const std::string &charset = " \t\n\r");
size_t findCorrespondingClosingBrace(const std::string &str, size_t openPos);
void removeEmptyLines(std::string &str);
@ -15,4 +17,6 @@ void removeComments(std::string &str);
std::vector<std::string> split(const std::string &str, char delimiter);
std::string implode(const std::vector<std::string> &elements, const std::string &delimiter);
uint32_t stateToEpoll(const ASocket::IOState &event);
} // namespace utils