Merge remote-tracking branch 'origin/feature/socket-refactor' into willem2

This commit is contained in:
whaffman 2025-10-21 19:03:32 +02:00
commit 7eba0b38a2
14 changed files with 161 additions and 47 deletions

View File

@ -35,6 +35,7 @@ RUN apt-get update && apt-get install -y \
nano \ nano \
# Include What You Use # Include What You Use
iwyu \ iwyu \
php-cgi \
# Clean up # Clean up
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*

View File

@ -106,26 +106,33 @@ void Client::request()
} }
} }
// // //
void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut, CgiSocket *cgiStdErr) // void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
// {
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
// }
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
// {
// server_.remove(*cgiSocket); // write
// sockets_.erase(cgiSocket->getFd());
// }
void Client::addSocket(ASocket *socket)
{ {
server_.add(*cgiStdIn, EPOLLOUT, this); // write server_.add(*socket, this);
server_.add(*cgiStdOut, EPOLLIN, this); // read sockets_[socket->getFd()] = socket;
server_.add(*cgiStdErr, EPOLLIN, this); // error
sockets_[cgiStdIn->getFd()] = cgiStdIn;
sockets_[cgiStdOut->getFd()] = cgiStdOut;
sockets_[cgiStdErr->getFd()] = cgiStdErr;
} }
void Client::removeCgiSocket(CgiSocket *cgiSocket) void Client::removeSocket(ASocket *socket)
{ {
server_.remove(*cgiSocket); // write server_.remove(*socket);
sockets_.erase(socket->getFd());
sockets_.erase(cgiSocket->getFd());
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
} }
void Client::poll() const void Client::poll() const
@ -141,7 +148,8 @@ void Client::poll() const
{ {
Log::info("Response is ready to be sent to client, fd: " + std::to_string(clientSocket_->getFd())); Log::info("Response is ready to be sent to client, fd: " + std::to_string(clientSocket_->getFd()));
clientSocket_->setCallback([this]() { respond(); }); clientSocket_->setCallback([this]() { respond(); });
server_.writable(clientSocket_->getFd()); // server_.writable(clientSocket_->getFd());
clientSocket_->setIOState(ASocket::IOState::WRITE);
} }
} }

View File

@ -43,8 +43,11 @@ class Client
[[nodiscard]] ASocket &getSocket(int fd = -1) const; [[nodiscard]] ASocket &getSocket(int fd = -1) const;
// void setStatusCode(int code); // void setStatusCode(int code);
void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut, CgiSocket *cgiStdErr); // void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut);
void removeCgiSocket(CgiSocket *cgiSocket); // void removeCgiSocket(CgiSocket *cgiSocket);
void addSocket(ASocket *socket);
void removeSocket(ASocket *socket);
[[nodiscard]] HttpRequest &getHttpRequest() const noexcept; [[nodiscard]] HttpRequest &getHttpRequest() const noexcept;
[[nodiscard]] HttpResponse &getHttpResponse() const noexcept; [[nodiscard]] HttpResponse &getHttpResponse() const noexcept;

View File

@ -34,7 +34,7 @@ void CgiHandler::write()
if (request_.getBody().empty()) if (request_.getBody().empty())
{ {
Log::debug("No body to write to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd())); Log::debug("No body to write to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
request_.getClient().removeCgiSocket(cgiStdIn_.get()); request_.getClient().removeSocket(cgiStdIn_.get());
cgiStdIn_ = nullptr; cgiStdIn_ = nullptr;
return; return;
} }
@ -48,7 +48,7 @@ void CgiHandler::write()
Log::debug("Wrote " + std::to_string(bytesWritten) Log::debug("Wrote " + std::to_string(bytesWritten)
+ " bytes to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd())); + " bytes to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
} }
request_.getClient().removeCgiSocket(cgiStdIn_.get()); request_.getClient().removeSocket(cgiStdIn_.get());
cgiStdIn_ = nullptr; cgiStdIn_ = nullptr;
} }
@ -70,7 +70,7 @@ void CgiHandler::read()
else if (bytesRead == 0) else if (bytesRead == 0)
{ {
Log::info("CGI process closed stdout, fd: " + std::to_string(cgiStdOut_->getFd())); Log::info("CGI process closed stdout, fd: " + std::to_string(cgiStdOut_->getFd()));
request_.getClient().removeCgiSocket(cgiStdOut_.get()); request_.getClient().removeSocket(cgiStdOut_.get());
cgiStdOut_ = nullptr; cgiStdOut_ = nullptr;
parseCgiOutput(); parseCgiOutput();
return; return;
@ -124,9 +124,8 @@ void CgiHandler::setCgiSockets(std::unique_ptr<CgiSocket> cgiStdIn, std::unique_
cgiStdIn_ = std::move(cgiStdIn); cgiStdIn_ = std::move(cgiStdIn);
cgiStdErr_ = std::move(cgiStdErr); cgiStdErr_ = std::move(cgiStdErr);
request_.getClient().setCgiSockets(cgiStdIn_.get(), cgiStdOut_.get(), cgiStdErr_.get()); // write request_.getClient().addSocket(cgiStdIn_.get());
request_.getClient().addSocket(cgiStdOut_.get());
// TODO add to handler
} }
void CgiHandler::wait() noexcept void CgiHandler::wait() noexcept

View File

@ -82,10 +82,8 @@ void CgiProcess::spawn()
else else
{ {
// Parent process // Parent process
auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1]); auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1], ASocket::IOState::WRITE);
auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0]); auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0], ASocket::IOState::READ);
auto cgiStdErr = std::make_unique<CgiSocket>(pipeStderr[0]);
close(pipeStdin[0]); close(pipeStdin[0]);
close(pipeStdout[1]); close(pipeStdout[1]);
close(pipeStderr[1]); close(pipeStderr[1]);

View File

@ -1,3 +1,5 @@
#include "webserv/utils/utils.hpp"
#include <webserv/client/Client.hpp> // for Client #include <webserv/client/Client.hpp> // for Client
#include <webserv/config/ConfigManager.hpp> // for ConfigManager #include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/config/ServerConfig.hpp> // for ServerConfig #include <webserv/config/ServerConfig.hpp> // for ServerConfig
@ -59,7 +61,7 @@ Server::~Server()
} }
} }
void Server::add(const ASocket &socket, uint32_t events, Client *client) void Server::add(ASocket &socket, Client *client)
{ {
if (socket.getType() != ASocket::Type::SERVER_SOCKET && client == nullptr) if (socket.getType() != ASocket::Type::SERVER_SOCKET && client == nullptr)
{ {
@ -69,7 +71,7 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
Log::trace(LOCATION); Log::trace(LOCATION);
int fd = socket.getFd(); int fd = socket.getFd();
struct epoll_event event{}; struct epoll_event event{};
event.events = events; event.events = utils::stateToEpoll(socket.getEvent());
event.data.fd = fd; event.data.fd = fd;
if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == -1) if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == -1)
{ {
@ -78,9 +80,10 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
} }
Log::debug("Socket added to epoll, fd: " + std::to_string(fd)); Log::debug("Socket added to epoll, fd: " + std::to_string(fd));
socketToClient_[fd] = client; socketToClient_[fd] = client;
sockets_.insert(&socket);
} }
void Server::remove(const ASocket &socket) void Server::remove(ASocket &socket)
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
int fd = socket.getFd(); int fd = socket.getFd();
@ -90,6 +93,7 @@ void Server::remove(const ASocket &socket)
throw std::runtime_error("epoll_ctl DEL failed"); throw std::runtime_error("epoll_ctl DEL failed");
} }
socketToClient_.erase(fd); socketToClient_.erase(fd);
sockets_.erase(&socket);
} }
void Server::disconnect(const Client &client) void Server::disconnect(const Client &client)
@ -112,7 +116,7 @@ void Server::setupServerSocket(const ServerConfig &config)
serverSocket->listen(SOMAXCONN); serverSocket->listen(SOMAXCONN);
int server_fd = serverSocket->getFd(); int server_fd = serverSocket->getFd();
add(*serverSocket, EPOLLIN); add(*serverSocket);
listeners_.push_back(std::move(serverSocket)); listeners_.push_back(std::move(serverSocket));
listener_fds_.insert(server_fd); listener_fds_.insert(server_fd);
@ -131,7 +135,7 @@ void Server::handleConnection(struct epoll_event *event)
std::unique_ptr<ClientSocket> clientSocket = listener.accept(); std::unique_ptr<ClientSocket> clientSocket = listener.accept();
auto client = std::make_unique<Client>(std::move(clientSocket), *this); auto client = std::make_unique<Client>(std::move(clientSocket), *this);
add(client->getSocket(), EPOLLIN, client.get()); add(client->getSocket(), client.get());
clients_.emplace_back(std::move(client)); clients_.emplace_back(std::move(client));
} }
@ -182,6 +186,23 @@ void Server::writable(int client_fd) const
} }
} }
void Server::update(const ASocket &socket) const
{
Log::trace(LOCATION);
int socketFd = socket.getFd();
uint32_t events = utils::stateToEpoll(socket.getEvent());
Log::debug("Socket (" + std::to_string(socket.getFd()) + ") is being updated");
struct epoll_event evt{};
evt.events = events;
evt.data.fd = socketFd;
if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, socketFd, &evt) == -1)
{
Log::error("epoll_ctl MOD failed for fd: " + std::to_string(socketFd));
throw std::runtime_error("epoll_ctl MOD failed");
}
}
void Server::handleResponse(struct epoll_event *event) const void Server::handleResponse(struct epoll_event *event) const
{ {
int socket_fd = event->data.fd; int socket_fd = event->data.fd;
@ -254,6 +275,18 @@ void Server::pollClients() const
} }
} }
void Server::pollSockets()
{
for (auto *socket : sockets_)
{
if (socket->isDirty())
{
update(*socket);
socket->processed();
}
}
}
void Server::run() void Server::run()
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
@ -262,6 +295,7 @@ void Server::run()
struct epoll_event events[MAX_EVENTS]; // NOLINT struct epoll_event events[MAX_EVENTS]; // NOLINT
while (true) while (true)
{ {
pollSockets();
pollClients(); pollClients();
handleEpoll(events, MAX_EVENTS); handleEpoll(events, MAX_EVENTS);
} }

View File

@ -33,8 +33,9 @@ class Server
~Server(); ~Server();
void run(); void run();
void add(const ASocket &socket, uint32_t events, Client *client = nullptr); void add(ASocket &socket, Client *client = nullptr);
void remove(const ASocket &socket); void remove(ASocket &socket);
void update(const ASocket &socket) const;
void disconnect(const Client &client); void disconnect(const Client &client);
void writable(int client_fd) const; void writable(int client_fd) const;
@ -48,8 +49,10 @@ class Server
std::set<int> listener_fds_; std::set<int> listener_fds_;
std::vector<std::unique_ptr<Client>> clients_; std::vector<std::unique_ptr<Client>> clients_;
std::unordered_map<int, Client *> socketToClient_; std::unordered_map<int, Client *> socketToClient_;
std::set<ASocket *> sockets_;
void pollClients() const; void pollClients() const;
void pollSockets();
void handleEpoll(struct epoll_event *events, int max_events); void handleEpoll(struct epoll_event *events, int max_events);
void handleEpollHangUp(struct epoll_event *event) const; void handleEpollHangUp(struct epoll_event *event) const;

View File

@ -10,7 +10,7 @@
#include <sys/socket.h> // for recv, send #include <sys/socket.h> // for recv, send
#include <unistd.h> // for close #include <unistd.h> // for close
ASocket::ASocket(int fd) : fd_(fd) ASocket::ASocket(int fd, IOState event) : fd_(fd), ioState_(event)
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
if (fd_ == -1) if (fd_ == -1)
@ -65,6 +65,34 @@ int ASocket::getFd() const noexcept
return fd_; return fd_;
} }
ASocket::IOState ASocket::getEvent() const noexcept
{
return ioState_;
}
bool ASocket::isDirty() const noexcept
{
return dirty_;
}
void ASocket::setIOState(IOState event)
{
if (event == ioState_)
{
return;
}
Log::debug("Processing state change for socket " + std::to_string(fd_));
dirty_ = true;
ioState_ = event;
}
void ASocket::processed()
{
Log::debug("Socket " + std::to_string(fd_) + " processed");
dirty_ = false;
}
void ASocket::setFd(int fd) void ASocket::setFd(int fd)
{ {
fd_ = fd; fd_ = fd;

View File

@ -16,8 +16,15 @@ class ASocket
CGI_SOCKET CGI_SOCKET
}; };
enum class IOState : uint32_t
{
NONE = 0,
READ = 1 << 0,
WRITE = 1 << 1,
};
ASocket() = delete; ASocket() = delete;
explicit ASocket(int fd); explicit ASocket(int fd, IOState state = IOState::READ);
ASocket(const ASocket &other) = delete; ASocket(const ASocket &other) = delete;
ASocket &operator=(const ASocket &other) = delete; ASocket &operator=(const ASocket &other) = delete;
@ -28,6 +35,8 @@ class ASocket
[[nodiscard]] virtual Type getType() const noexcept = 0; [[nodiscard]] virtual Type getType() const noexcept = 0;
[[nodiscard]] int getFd() const noexcept; [[nodiscard]] int getFd() const noexcept;
[[nodiscard]] IOState getEvent() const noexcept;
[[nodiscard]] bool isDirty() const noexcept;
void callback() const; void callback() const;
void setCallback(std::function<void()> callback); void setCallback(std::function<void()> callback);
@ -35,11 +44,16 @@ class ASocket
virtual ssize_t read(void *buf, size_t len) const; virtual ssize_t read(void *buf, size_t len) const;
virtual ssize_t write(const void *buf, size_t len) const; virtual ssize_t write(const void *buf, size_t len) const;
void setIOState(IOState event);
void processed();
protected: protected:
void setNonBlocking() const; void setNonBlocking() const;
void setFd(int fd); void setFd(int fd);
private: private:
int fd_; int fd_;
bool dirty_ = false;
IOState ioState_;
std::function<void()> callback_ = nullptr; std::function<void()> callback_ = nullptr;
}; };

View File

@ -1,9 +1,10 @@
#include <webserv/log/Log.hpp> // for LOCATION, Log #include <webserv/log/Log.hpp> // for LOCATION, Log
#include <webserv/socket/ASocket.hpp> // for ASocket #include <webserv/socket/ASocket.hpp> // for ASocket
#include <webserv/socket/CgiSocket.hpp> #include <webserv/socket/CgiSocket.hpp>
#include <unistd.h> #include <unistd.h>
CgiSocket::CgiSocket(int fd) : ASocket(fd) CgiSocket::CgiSocket(int fd, ASocket::IOState event) : ASocket(fd, event)
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
} }

View File

@ -9,7 +9,7 @@
class CgiSocket : public ASocket class CgiSocket : public ASocket
{ {
public: public:
explicit CgiSocket(int fd); explicit CgiSocket(int fd, ASocket::IOState event);
[[nodiscard]] ASocket::Type getType() const noexcept override; [[nodiscard]] ASocket::Type getType() const noexcept override;

View File

@ -51,7 +51,7 @@ void ServerSocket::bind(const std::string &host, const int port) const
struct sockaddr_in address{}; struct sockaddr_in address{};
address.sin_family = AF_INET; address.sin_family = AF_INET;
address.sin_addr.s_addr = inet_addr(host.c_str()); address.sin_addr.s_addr = inet_addr(host.c_str());
address.sin_port = htons(port); address.sin_port = htons(static_cast<uint16_t>(port));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
if (::bind(getFd(), reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0) if (::bind(getFd(), reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0)

View File

@ -1,11 +1,17 @@
#include "webserv/socket/ASocket.hpp"
#include <webserv/utils/utils.hpp> #include <webserv/utils/utils.hpp>
#include <cstdint>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <type_traits>
#include <vector> #include <vector>
#include <sys/epoll.h>
namespace utils namespace utils
{ {
size_t stoul(const std::string &str, int base) size_t stoul(const std::string &str, int base)
@ -132,4 +138,19 @@ std::string implode(const std::vector<std::string> &elements, const std::string
} }
return stream.str(); return stream.str();
} }
uint32_t stateToEpoll(const ASocket::IOState &event)
{
uint32_t epollEvents = 0;
using EventType = std::underlying_type_t<ASocket::IOState>;
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::READ)) != 0U)
{
epollEvents |= EPOLLIN;
}
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::WRITE)) != 0U)
{
epollEvents |= EPOLLOUT;
}
return epollEvents;
}
} // namespace utils } // namespace utils

View File

@ -1,13 +1,15 @@
#pragma once #pragma once
#include "webserv/socket/ASocket.hpp"
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint>
#include <string> // for string #include <string> // for string
#include <vector> #include <vector>
namespace utils namespace utils
{ {
size_t stoul(const std::string &str, int base = 10); size_t stoul(const std::string &str, int base = 10);
// std::string trimSemi(const std::string &str);
std::string trim(const std::string &str, const std::string &charset = " \t\n\r"); std::string trim(const std::string &str, const std::string &charset = " \t\n\r");
size_t findCorrespondingClosingBrace(const std::string &str, size_t openPos); size_t findCorrespondingClosingBrace(const std::string &str, size_t openPos);
void removeEmptyLines(std::string &str); void removeEmptyLines(std::string &str);
@ -15,4 +17,6 @@ void removeComments(std::string &str);
std::vector<std::string> split(const std::string &str, char delimiter); std::vector<std::string> split(const std::string &str, char delimiter);
std::string implode(const std::vector<std::string> &elements, const std::string &delimiter); std::string implode(const std::vector<std::string> &elements, const std::string &delimiter);
uint32_t stateToEpoll(const ASocket::IOState &event);
} // namespace utils } // namespace utils