From 4cf3eac4d14ee1a34d873941aa7986950e66144d Mon Sep 17 00:00:00 2001 From: Quinten Date: Sun, 26 Oct 2025 19:18:11 +0100 Subject: [PATCH] feat: request validation --- webserv/client/Client.cpp | 36 ++++++----------- webserv/http/HttpConstants.hpp | 4 +- webserv/http/HttpRequest.cpp | 21 +++++++--- webserv/http/HttpRequest.hpp | 2 + webserv/http/RequestValidator.cpp | 67 +++++++++++++++++++++++++++++++ webserv/http/RequestValidator.hpp | 21 ++++------ webserv/router/Router.cpp | 43 ++++++++++++-------- webserv/router/Router.hpp | 5 ++- 8 files changed, 137 insertions(+), 62 deletions(-) create mode 100644 webserv/http/RequestValidator.cpp diff --git a/webserv/client/Client.cpp b/webserv/client/Client.cpp index cb66b8f..5a0c4ca 100644 --- a/webserv/client/Client.cpp +++ b/webserv/client/Client.cpp @@ -1,5 +1,4 @@ #include - #include // for CgiHandler #include // for ErrorHandler #include // for HttpHeaders @@ -91,15 +90,23 @@ void Client::request() {"body", httpRequest_->getBody()}, {"state", std::to_string(static_cast(httpRequest_->getState()))}, }); - // server_.responseReady(client_socket_->getFd()); - handler_ = router_->handleRequest(); - if (handler_ != nullptr) + + try { - handler_->handle(); + // Thoughts: if a handler isn't returned, this could because of the error handler already setting up the response + // so, maybe we don't need to throw a 500 when no handler. Because that would override the actual error response. + // How about the router, or a handler, throws an exception if something goes wrong, and we catch it here to make a 500 response? + handler_ = router_->handleRequest(); + if (handler_ != nullptr) + { + handler_->handle(); + } } - else + catch (const std::exception &e) { + Log::error("Exception during request handling: " + std::string(e.what())); ErrorHandler::createErrorResponse(500, *httpResponse_); + return; } } else @@ -112,23 +119,6 @@ void Client::request() } } -// // -// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut) -// { -// // server_.add(*cgiStdIn, EPOLLOUT, this); // write -// // server_.add(*cgiStdOut, EPOLLIN, this); // read - -// sockets_[cgiStdIn->getFd()] = cgiStdIn; -// sockets_[cgiStdOut->getFd()] = cgiStdOut; -// } - -// void Client::removeCgiSocket(CgiSocket *cgiSocket) -// { -// server_.remove(*cgiSocket); // write - -// sockets_.erase(cgiSocket->getFd()); -// } - void Client::addSocket(ASocket *socket) { server_.add(*socket, this); diff --git a/webserv/http/HttpConstants.hpp b/webserv/http/HttpConstants.hpp index 5e6f6f8..325cfb7 100644 --- a/webserv/http/HttpConstants.hpp +++ b/webserv/http/HttpConstants.hpp @@ -40,6 +40,7 @@ constexpr uint16_t UNAUTHORIZED = 401; constexpr uint16_t FORBIDDEN = 403; constexpr uint16_t NOT_FOUND = 404; constexpr uint16_t METHOD_NOT_ALLOWED = 405; +constexpr uint16_t PAYLOAD_TOO_LARGE = 413; constexpr uint16_t INTERNAL_SERVER_ERROR = 500; constexpr uint16_t NOT_IMPLEMENTED = 501; constexpr uint16_t BAD_GATEWAY = 502; @@ -52,7 +53,7 @@ struct StatusCodeInfo std::string_view reason; }; -static constexpr std::array statusCodeInfos +static constexpr std::array statusCodeInfos = {{{.code = StatusCode::OK, .reason = "OK"}, {.code = StatusCode::CREATED, .reason = "Created"}, {.code = StatusCode::NO_CONTENT, .reason = "No Content"}, @@ -61,6 +62,7 @@ static constexpr std::array statusCodeInfos {.code = StatusCode::FORBIDDEN, .reason = "Forbidden"}, {.code = StatusCode::NOT_FOUND, .reason = "Not Found"}, {.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"}, + {.code = StatusCode::PAYLOAD_TOO_LARGE, .reason = "Payload Too Large"}, {.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"}, {.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"}, {.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"}, diff --git a/webserv/http/HttpRequest.cpp b/webserv/http/HttpRequest.cpp index 248f1a5..3a683f4 100644 --- a/webserv/http/HttpRequest.cpp +++ b/webserv/http/HttpRequest.cpp @@ -1,8 +1,9 @@ -#include +#include "webserv/config/ServerConfig.hpp" #include // for ConfigManager #include // for URI #include // for CRLF, DOUBLE_CRLF +#include #include // for Log, LOCATION #include // for stoul @@ -32,12 +33,10 @@ void HttpRequest::setState(State state) { if (state == State::Complete) { - // TODO: segfault if server does not exist - std::string hostHeader = getHeaders().getHost().value_or(""); - ServerConfig *serverConfig = ConfigManager::getInstance().getMatchingServerConfig(hostHeader); - if (hostHeader.empty() || serverConfig == nullptr) + ServerConfig *serverConfig = getServerConfig(); + if (serverConfig == nullptr) { - Log::error("No matching server config found for host: " + hostHeader); + Log::error("No matching server config found for host"); state_ = State::ParseError; return; } @@ -246,6 +245,16 @@ void HttpRequest::reset() // contentLength_ = 0; } +ServerConfig *HttpRequest::getServerConfig() const +{ + std::string hostHeader = getHeaders().getHost().value_or(""); + if (hostHeader.empty()) + { + return nullptr; + } + return ConfigManager::getInstance().getMatchingServerConfig(hostHeader); +} + const URI &HttpRequest::getUri() const noexcept { return *uri_; diff --git a/webserv/http/HttpRequest.hpp b/webserv/http/HttpRequest.hpp index 7dfd0c2..94d81ff 100644 --- a/webserv/http/HttpRequest.hpp +++ b/webserv/http/HttpRequest.hpp @@ -56,6 +56,8 @@ class HttpRequest void parseBuffer(); void parseContentLength(); + ServerConfig * getServerConfig() const; + Client *client_; State state_ = State::RequestLine; diff --git a/webserv/http/RequestValidator.cpp b/webserv/http/RequestValidator.cpp new file mode 100644 index 0000000..7ee1055 --- /dev/null +++ b/webserv/http/RequestValidator.cpp @@ -0,0 +1,67 @@ +#include "webserv/config/AConfig.hpp" + +#include + +#include +#include +#include + +RequestValidator::RequestValidator(const AConfig *config, const HttpRequest *request) : config(config), request(request) +{ +} + +std::optional RequestValidator::validate() const +{ + if (auto error = validateContentLength()) + { + return error; + } + if (auto error = validateMethod()) + { + return error; + } + return std::nullopt; // No validation errors +} + +std::optional RequestValidator::validateContentLength() const +{ + size_t bodySize = request->getBody().size(); + size_t maxBodySize = config->get("client_max_body_size").value_or(1024 * 1024); + if (bodySize > maxBodySize) // exceed server limit + { + return ValidationError{413, "Payload Too Large"}; + } + // If Content-Length header is present, validate it + if (request->getHeaders().getContentLength()) + { + size_t contentLength = *request->getHeaders().getContentLength(); + if (contentLength != bodySize) + { + return ValidationError{400, "Bad Request: Content-Length does not match body size"}; + } + } + return std::nullopt; // Content-Length is valid +} + +std::optional RequestValidator::validateMethod() const +{ + auto allowedMethodsOpt = config->get>("allowed_methods"); + std::vector allowedMethods; + + if (allowedMethodsOpt.has_value()) + { + allowedMethods = std::move(*allowedMethodsOpt); + } + else + { + allowedMethods = {"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"}; + } + for (const std::string &method : allowedMethods) + { + if (method == request->getMethod()) + { + return std::nullopt; // Method is allowed + } + } + return ValidationError{405, "Method Not Allowed"}; +} diff --git a/webserv/http/RequestValidator.hpp b/webserv/http/RequestValidator.hpp index 4ec1249..961a720 100644 --- a/webserv/http/RequestValidator.hpp +++ b/webserv/http/RequestValidator.hpp @@ -1,5 +1,6 @@ #pragma once +#include "webserv/config/AConfig.hpp" #include #include @@ -14,20 +15,14 @@ class RequestValidator int statusCode; std::string message; }; - - RequestValidator() = delete; - RequestValidator(const ServerConfig &config, const HttpRequest &request); - RequestValidator(const RequestValidator &other) = delete; - RequestValidator(RequestValidator &&other) = delete; - - RequestValidator &operator=(const RequestValidator &other) = delete; - RequestValidator &operator=(RequestValidator &&other) = delete; - - ~RequestValidator() = default; + RequestValidator(const AConfig *config, const HttpRequest *request); [[nodiscard]] std::optional validate() const; private: - ServerConfig &config; - HttpRequest &request; -}; \ No newline at end of file + const AConfig *config; + const HttpRequest *request; + + [[nodiscard]] std::optional validateContentLength() const; + [[nodiscard]] std::optional validateMethod() const; + }; \ No newline at end of file diff --git a/webserv/router/Router.cpp b/webserv/router/Router.cpp index 6807e4e..84c8dd5 100644 --- a/webserv/router/Router.cpp +++ b/webserv/router/Router.cpp @@ -1,4 +1,4 @@ -#include // for Router +#include "webserv/handler/ErrorHandler.hpp" #include // for Client #include // for AConfig @@ -10,6 +10,7 @@ #include // for URI #include // for HttpRequest #include // for Log, LOCATION +#include // for Router #include // for exception #include // for vector @@ -26,16 +27,16 @@ Router::Router(Client *client) : client_(client) Log::trace(LOCATION); } -bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept -{ - const ADirective *allowedMethods = config.getDirective("allowed_methods"); - if (allowedMethods == nullptr || !allowedMethods->getValue().try_get>().has_value()) - { - return true; - } - auto methods = allowedMethods->getValue().get>(); - return std::ranges::find(methods, method) != methods.end(); -} +// bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept +// { +// const ADirective *allowedMethods = config.getDirective("allowed_methods"); +// if (allowedMethods == nullptr || !allowedMethods->getValue().try_get>().has_value()) +// { +// return true; +// } +// auto methods = allowedMethods->getValue().get>(); +// return std::ranges::find(methods, method) != methods.end(); +// } std::unique_ptr Router::handleRequest() { @@ -50,17 +51,25 @@ std::unique_ptr Router::handleRequest() } HttpResponse &response = client_->getHttpResponse(); - const std::string &target = request.getTarget(); - static_cast(target); // Suppress unused variable warning - const std::string &method = request.getMethod(); + // const std::string &target = request.getTarget(); + // static_cast(target); // Suppress unused variable warning + // const std::string &method = request.getMethod(); const AConfig *config = request.getUri().getConfig(); - - if (!isMethodSupported(method, *config)) + auto validator = std::make_unique(config, &request); + auto error = validator->validate(); + if (error.has_value()) { + Log::warning("Request validation failed: " + error->message); + ErrorHandler::createErrorResponse(error->statusCode, response, config); return nullptr; - // return ErrorHandler::getErrorResponse(405, config); } + + // if (!isMethodSupported(method, *config)) + // { + // return nullptr; + // // return ErrorHandler::getErrorResponse(405, config); + // } if (request.getUri().isCgi()) { try diff --git a/webserv/router/Router.hpp b/webserv/router/Router.hpp index c399270..9327bcd 100644 --- a/webserv/router/Router.hpp +++ b/webserv/router/Router.hpp @@ -1,5 +1,6 @@ #pragma once +#include "webserv/http/RequestValidator.hpp" #include // for AConfig #include #include // for AHandler @@ -7,7 +8,6 @@ #include // for HttpResponse #include // for unique_ptr -#include // for string class LocationConfig; class ServerConfig; @@ -21,5 +21,6 @@ class Router private: Client *client_; - [[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept; + std::unique_ptr requestValidator_; + // [[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept; }; \ No newline at end of file