remove client when it disconnects

This commit is contained in:
Quinten 2025-09-30 17:18:09 +02:00
parent bd01a172c5
commit 1c634ddc02
7 changed files with 65 additions and 12 deletions

View File

@ -1,3 +1,4 @@
#include "webserv/http/HttpConstants.hpp"
#include <webserv/client/Client.hpp> #include <webserv/client/Client.hpp>
#include <webserv/config/ConfigManager.hpp> // for ConfigManager #include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/config/ServerConfig.hpp> // for ServerConfig #include <webserv/config/ServerConfig.hpp> // for ServerConfig
@ -18,27 +19,32 @@ Client::Client(std::unique_ptr<Socket> socket, Server &server)
: client_socket_(std::move(socket)), server_(std::ref(server)), httpRequest_(std::make_unique<HttpRequest>(this)), : client_socket_(std::move(socket)), server_(std::ref(server)), httpRequest_(std::make_unique<HttpRequest>(this)),
httpResponse_(std::make_unique<HttpResponse>()) httpResponse_(std::make_unique<HttpResponse>())
{ {
Log::trace(LOCATION);
Log::info("New client connected, fd: " + std::to_string(client_socket_->getFd())); Log::info("New client connected, fd: " + std::to_string(client_socket_->getFd()));
} }
Client::~Client() Client::~Client()
{ {
Log::trace(LOCATION);
Log::info("Client disconnected, fd: " + std::to_string(client_socket_->getFd())); Log::info("Client disconnected, fd: " + std::to_string(client_socket_->getFd()));
server_.removeFromEpoll(*client_socket_); server_.removeFromEpoll(*client_socket_);
}; };
int Client::getStatusCode() const int Client::getStatusCode() const
{ {
Log::trace(LOCATION);
return statusCode_; return statusCode_;
} }
void Client::setStatusCode(int code) void Client::setStatusCode(int code)
{ {
Log::trace(LOCATION);
statusCode_ = code; statusCode_ = code;
} }
void Client::request() void Client::request()
{ {
Log::trace(LOCATION);
char buffer[bufferSize_] = {}; // NOLINT(cppcoreguidelines-avoid-c-arrays) char buffer[bufferSize_] = {}; // NOLINT(cppcoreguidelines-avoid-c-arrays)
ssize_t bytesRead = ssize_t bytesRead =
client_socket_->recv(buffer, sizeof(buffer) - 1); // NOLINT(cppcoreguidelines-pro-bounds-array-to-pointer-decay) client_socket_->recv(buffer, sizeof(buffer) - 1); // NOLINT(cppcoreguidelines-pro-bounds-array-to-pointer-decay)
@ -49,7 +55,8 @@ void Client::request()
} }
if (bytesRead == 0) if (bytesRead == 0)
{ {
Log::info("Client disconnected, fd: " + std::to_string(client_socket_->getFd())); // TODO weird Log::info("Client closed connection, fd: " + std::to_string(client_socket_->getFd())); // TODO weird
server_.removeClient(*this); // CRITICAL: RETURN IMMEDIATELY
return; return;
} }
@ -74,7 +81,7 @@ void Client::request()
{ {
Log::warning("No matching server config found for Host: " + Log::warning("No matching server config found for Host: " +
httpRequest_->getHeaders().getHost().value_or("unknown host")); httpRequest_->getHeaders().getHost().value_or("unknown host"));
httpRequest_->setState(HttpRequest::State::ParseError); setError(Http::StatusCode::BAD_REQUEST);
} }
// Example usage, replace with actual host and port extraction from request // Example usage, replace with actual host and port extraction from request
@ -92,6 +99,7 @@ void Client::request()
bool Client::isResponseReady() const bool Client::isResponseReady() const
{ {
Log::trace(LOCATION);
// todo: poll the httpResponse_ object // todo: poll the httpResponse_ object
return httpResponse_->isComplete(); return httpResponse_->isComplete();
} }
@ -99,13 +107,23 @@ bool Client::isResponseReady() const
std::vector<uint8_t> Client::getResponse() const std::vector<uint8_t> Client::getResponse() const
{ {
Log::trace(LOCATION); Log::trace(LOCATION);
if (httpRequest_->getState() == HttpRequest::State::ParseError) Log::trace(LOCATION);
if (statusCode_ == Http::StatusCode::OK)
{ {
return ErrorHandler::getErrorResponse(Http::StatusCode::BAD_REQUEST).toBytes(); httpResponse_->setStatus(200);
httpResponse_->addHeader("Content-Type", "text/plain");
httpResponse_->appendBody("Hello, World!\n");
} }
httpResponse_->setStatus(200);
httpResponse_->addHeader("Content-Type", "text/plain");
httpResponse_->appendBody("Hello, World!\n");
return httpResponse_->toBytes(); return httpResponse_->toBytes();
} }
void Client::setError(int statusCode)
{
Log::trace(LOCATION);
statusCode_ = statusCode;
Log::debug("Setting error response with status code: " + std::to_string(statusCode));
auto errorResponse = std::make_unique<HttpResponse>(
ErrorHandler::getErrorResponse(statusCode, const_cast<ServerConfig *>(server_config_)));
httpResponse_ = std::move(errorResponse);
Log::debug("Error response set successfully");
}

View File

@ -39,6 +39,7 @@ class Client
[[nodiscard]] int getStatusCode() const; [[nodiscard]] int getStatusCode() const;
[[nodiscard]] Socket &getSocket() const { return *client_socket_; } [[nodiscard]] Socket &getSocket() const { return *client_socket_; }
void setError(int statusCode);
void setStatusCode(int code); void setStatusCode(int code);

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "webserv/http/HttpHeaders.hpp" #include "webserv/http/HttpHeaders.hpp"
#include "webserv/log/Log.hpp"
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
@ -19,7 +20,7 @@ class HttpResponse
HttpResponse(HttpResponse &&other) noexcept = default; // Move constructor HttpResponse(HttpResponse &&other) noexcept = default; // Move constructor
HttpResponse &operator=(HttpResponse &&other) noexcept = default; // Move assignment HttpResponse &operator=(HttpResponse &&other) noexcept = default; // Move assignment
~HttpResponse() = default; ~HttpResponse() {Log::trace(LOCATION);};
void addHeader(const std::string &key, const std::string &value); void addHeader(const std::string &key, const std::string &value);

View File

@ -15,7 +15,7 @@ int main(int argc, char **argv)
std::cerr << "Usage: " << argv[0] << " <config_file_path>\n"; // NOLINT std::cerr << "Usage: " << argv[0] << " <config_file_path>\n"; // NOLINT
return 1; return 1;
} }
Log::setFileChannel("webserv.log", std::ios_base::app, Log::Level::Trace); Log::setFileChannel("webserv.log", std::ios_base::app, Log::Level::Info);
Log::setStdoutChannel(Log::Level::Info); Log::setStdoutChannel(Log::Level::Info);
Log::info("\n======================\nStarting webserv...\n======================\n"); Log::info("\n======================\nStarting webserv...\n======================\n");
ConfigManager::getInstance().init(argv[1]); // NOLINT ConfigManager::getInstance().init(argv[1]); // NOLINT

View File

@ -22,6 +22,7 @@
Server::Server(const ConfigManager &configManager) : epoll_fd_(epoll_create1(0)), configManager_(configManager) Server::Server(const ConfigManager &configManager) : epoll_fd_(epoll_create1(0)), configManager_(configManager)
{ {
Log::trace(LOCATION);
const auto &serverConfigs = configManager.getServerConfigs(); const auto &serverConfigs = configManager.getServerConfigs();
if (serverConfigs.empty()) if (serverConfigs.empty())
{ {
@ -37,6 +38,7 @@ Server::Server(const ConfigManager &configManager) : epoll_fd_(epoll_create1(0))
Server::~Server() Server::~Server()
{ {
Log::trace(LOCATION);
if (epoll_fd_ != -1) if (epoll_fd_ != -1)
{ {
close(epoll_fd_); close(epoll_fd_);
@ -45,6 +47,7 @@ Server::~Server()
void Server::start() void Server::start()
{ {
Log::trace(LOCATION);
Log::info("Starting servers..."); Log::info("Starting servers...");
// 1. Load server configurations // 1. Load server configurations
@ -63,6 +66,7 @@ void Server::start()
void Server::addToEpoll(const Socket &socket, uint32_t events) const void Server::addToEpoll(const Socket &socket, uint32_t events) const
{ {
Log::trace(LOCATION);
int fd = socket.getFd(); int fd = socket.getFd();
struct epoll_event event{}; struct epoll_event event{};
event.events = events; event.events = events;
@ -74,8 +78,18 @@ void Server::addToEpoll(const Socket &socket, uint32_t events) const
} }
} }
void Server::removeClient(const Client &client)
{
Log::trace(LOCATION);
int client_fd = client.getSocket().getFd();
clients_.erase(client_fd);
// removeFromEpoll(client.getSocket());
// close(client_fd);
}
void Server::removeFromEpoll(const Socket &socket) const void Server::removeFromEpoll(const Socket &socket) const
{ {
Log::trace(LOCATION);
int filedes = socket.getFd(); int filedes = socket.getFd();
if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, filedes, nullptr) == -1) if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, filedes, nullptr) == -1)
{ {
@ -86,6 +100,7 @@ void Server::removeFromEpoll(const Socket &socket) const
void Server::setupServerSocket(const ServerConfig &config) void Server::setupServerSocket(const ServerConfig &config)
{ {
Log::trace(LOCATION);
try try
{ {
auto host = config.getDirectiveValue<std::string>("host"); auto host = config.getDirectiveValue<std::string>("host");
@ -110,6 +125,7 @@ void Server::setupServerSocket(const ServerConfig &config)
void Server::handleConnection(struct epoll_event *event) void Server::handleConnection(struct epoll_event *event)
{ {
Log::trace(LOCATION);
Socket &listener = getListener(event->data.fd); Socket &listener = getListener(event->data.fd);
std::unique_ptr<Socket> clientSocket = listener.accept(); std::unique_ptr<Socket> clientSocket = listener.accept();
addToEpoll(*clientSocket, EPOLLIN); addToEpoll(*clientSocket, EPOLLIN);
@ -118,6 +134,7 @@ void Server::handleConnection(struct epoll_event *event)
Socket &Server::getListener(int fd) const Socket &Server::getListener(int fd) const
{ {
Log::trace(LOCATION);
for (const auto &listener : listeners_) for (const auto &listener : listeners_)
{ {
if (listener->getFd() == fd) if (listener->getFd() == fd)
@ -131,6 +148,7 @@ Socket &Server::getListener(int fd) const
Client &Server::getClient(int fd) const Client &Server::getClient(int fd) const
{ {
Log::trace(LOCATION);
auto it = clients_.find(fd); auto it = clients_.find(fd);
if (it != clients_.end()) if (it != clients_.end())
{ {
@ -142,11 +160,13 @@ Client &Server::getClient(int fd) const
const ServerConfig &Server::getConfig(const Socket &socket) const const ServerConfig &Server::getConfig(const Socket &socket) const
{ {
Log::trace(LOCATION);
return getConfig(socket.getFd()); return getConfig(socket.getFd());
} }
const ServerConfig &Server::getConfig(int fd) const const ServerConfig &Server::getConfig(int fd) const
{ {
Log::trace(LOCATION);
auto it = fdToConfig_.find(fd); auto it = fdToConfig_.find(fd);
if (it != fdToConfig_.end()) if (it != fdToConfig_.end())
{ {
@ -158,6 +178,7 @@ const ServerConfig &Server::getConfig(int fd) const
void Server::handleRequest(struct epoll_event *event) const void Server::handleRequest(struct epoll_event *event) const
{ {
Log::trace(LOCATION);
int client_fd = event->data.fd; int client_fd = event->data.fd;
Client &client = getClient(client_fd); Client &client = getClient(client_fd);
@ -166,6 +187,7 @@ void Server::handleRequest(struct epoll_event *event) const
void Server::responseReady(int client_fd) const void Server::responseReady(int client_fd) const
{ {
Log::trace(LOCATION);
Log::debug("Response ready for client fd: " + std::to_string(client_fd)); Log::debug("Response ready for client fd: " + std::to_string(client_fd));
struct epoll_event ev{}; struct epoll_event ev{};
ev.events = EPOLLOUT; ev.events = EPOLLOUT;
@ -179,6 +201,7 @@ void Server::responseReady(int client_fd) const
void Server::eventLoop() void Server::eventLoop()
{ {
Log::trace(LOCATION);
Log::info("Listening..."); Log::info("Listening...");
const int MAX_EVENTS = 10; const int MAX_EVENTS = 10;
struct epoll_event events[MAX_EVENTS]; // NOLINT struct epoll_event events[MAX_EVENTS]; // NOLINT

View File

@ -33,6 +33,7 @@ class Server
void start(); void start();
void addToEpoll(const Socket &socket, uint32_t events) const; void addToEpoll(const Socket &socket, uint32_t events) const;
void removeClient(const Client &client);
void removeFromEpoll(const Socket &socket) const; void removeFromEpoll(const Socket &socket) const;
void setupServerSocket(const ServerConfig &config); void setupServerSocket(const ServerConfig &config);
void handleConnection(struct epoll_event *event); void handleConnection(struct epoll_event *event);

View File

@ -13,6 +13,7 @@
Socket::Socket() : fd_(socket(AF_INET, SOCK_STREAM, 0)) Socket::Socket() : fd_(socket(AF_INET, SOCK_STREAM, 0))
{ {
Log::trace(LOCATION);
if (fd_ == -1) if (fd_ == -1)
{ {
Log::error("Socket creation failed"); Log::error("Socket creation failed");
@ -31,6 +32,7 @@ Socket::Socket() : fd_(socket(AF_INET, SOCK_STREAM, 0))
Socket::Socket(int fd) : fd_(fd) // NOLINT(readability-identifier-naming) Socket::Socket(int fd) : fd_(fd) // NOLINT(readability-identifier-naming)
{ {
Log::trace(LOCATION);
if (fd_ == -1) if (fd_ == -1)
{ {
Log::error("Invalid file descriptor"); Log::error("Invalid file descriptor");
@ -41,6 +43,7 @@ Socket::Socket(int fd) : fd_(fd) // NOLINT(readability-identifier-naming)
Socket::~Socket() Socket::~Socket()
{ {
Log::trace(LOCATION);
if (fd_ != -1) if (fd_ != -1)
{ {
close(fd_); close(fd_);
@ -49,6 +52,7 @@ Socket::~Socket()
void Socket::listen(int backlog) const void Socket::listen(int backlog) const
{ {
Log::trace(LOCATION);
if (::listen(fd_, backlog) < 0) if (::listen(fd_, backlog) < 0)
{ {
Log::error("Listen failed"); Log::error("Listen failed");
@ -58,7 +62,7 @@ void Socket::listen(int backlog) const
void Socket::bind(const std::string &host, const int port) const void Socket::bind(const std::string &host, const int port) const
{ {
Log::trace(LOCATION);
struct sockaddr_in address{}; struct sockaddr_in address{};
address.sin_family = AF_INET; address.sin_family = AF_INET;
address.sin_addr.s_addr = inet_addr(host.c_str()); address.sin_addr.s_addr = inet_addr(host.c_str());
@ -72,6 +76,7 @@ void Socket::bind(const std::string &host, const int port) const
std::unique_ptr<Socket> Socket::accept() const std::unique_ptr<Socket> Socket::accept() const
{ {
Log::trace(LOCATION);
int client_fd = ::accept(fd_, nullptr, nullptr); int client_fd = ::accept(fd_, nullptr, nullptr);
if (client_fd < 0) if (client_fd < 0)
{ {
@ -83,16 +88,19 @@ std::unique_ptr<Socket> Socket::accept() const
ssize_t Socket::recv(void *buf, size_t len) const ssize_t Socket::recv(void *buf, size_t len) const
{ {
Log::trace(LOCATION);
return ::recv(fd_, buf, len, 0); return ::recv(fd_, buf, len, 0);
} }
ssize_t Socket::send(const void *buf, size_t len) const ssize_t Socket::send(const void *buf, size_t len) const
{ {
Log::trace(LOCATION);
return ::send(fd_, buf, len, 0); return ::send(fd_, buf, len, 0);
} }
void Socket::setNonBlocking() const void Socket::setNonBlocking() const
{ {
Log::trace(LOCATION);
if (fcntl(fd_, F_SETFL, O_NONBLOCK) < 0) if (fcntl(fd_, F_SETFL, O_NONBLOCK) < 0)
{ {
Log::error("Failed to set non-blocking mode"); Log::error("Failed to set non-blocking mode");
@ -102,5 +110,6 @@ void Socket::setNonBlocking() const
int Socket::getFd() const int Socket::getFd() const
{ {
Log::trace(LOCATION);
return fd_; return fd_;
} }