feat: request validation

This commit is contained in:
Quinten 2025-10-26 19:18:11 +01:00
parent d98fb98c52
commit 4cf3eac4d1
8 changed files with 137 additions and 62 deletions

View File

@ -1,5 +1,4 @@
#include <webserv/client/Client.hpp> #include <webserv/client/Client.hpp>
#include <webserv/handler/CgiHandler.hpp> // for CgiHandler #include <webserv/handler/CgiHandler.hpp> // for CgiHandler
#include <webserv/handler/ErrorHandler.hpp> // for ErrorHandler #include <webserv/handler/ErrorHandler.hpp> // for ErrorHandler
#include <webserv/http/HttpHeaders.hpp> // for HttpHeaders #include <webserv/http/HttpHeaders.hpp> // for HttpHeaders
@ -91,15 +90,23 @@ void Client::request()
{"body", httpRequest_->getBody()}, {"body", httpRequest_->getBody()},
{"state", std::to_string(static_cast<uint8_t>(httpRequest_->getState()))}, {"state", std::to_string(static_cast<uint8_t>(httpRequest_->getState()))},
}); });
// server_.responseReady(client_socket_->getFd());
try
{
// Thoughts: if a handler isn't returned, this could because of the error handler already setting up the response
// so, maybe we don't need to throw a 500 when no handler. Because that would override the actual error response.
// How about the router, or a handler, throws an exception if something goes wrong, and we catch it here to make a 500 response?
handler_ = router_->handleRequest(); handler_ = router_->handleRequest();
if (handler_ != nullptr) if (handler_ != nullptr)
{ {
handler_->handle(); handler_->handle();
} }
else }
catch (const std::exception &e)
{ {
Log::error("Exception during request handling: " + std::string(e.what()));
ErrorHandler::createErrorResponse(500, *httpResponse_); ErrorHandler::createErrorResponse(500, *httpResponse_);
return;
} }
} }
else else
@ -112,23 +119,6 @@ void Client::request()
} }
} }
// //
// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
// {
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
// }
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
// {
// server_.remove(*cgiSocket); // write
// sockets_.erase(cgiSocket->getFd());
// }
void Client::addSocket(ASocket *socket) void Client::addSocket(ASocket *socket)
{ {
server_.add(*socket, this); server_.add(*socket, this);

View File

@ -40,6 +40,7 @@ constexpr uint16_t UNAUTHORIZED = 401;
constexpr uint16_t FORBIDDEN = 403; constexpr uint16_t FORBIDDEN = 403;
constexpr uint16_t NOT_FOUND = 404; constexpr uint16_t NOT_FOUND = 404;
constexpr uint16_t METHOD_NOT_ALLOWED = 405; constexpr uint16_t METHOD_NOT_ALLOWED = 405;
constexpr uint16_t PAYLOAD_TOO_LARGE = 413;
constexpr uint16_t INTERNAL_SERVER_ERROR = 500; constexpr uint16_t INTERNAL_SERVER_ERROR = 500;
constexpr uint16_t NOT_IMPLEMENTED = 501; constexpr uint16_t NOT_IMPLEMENTED = 501;
constexpr uint16_t BAD_GATEWAY = 502; constexpr uint16_t BAD_GATEWAY = 502;
@ -52,7 +53,7 @@ struct StatusCodeInfo
std::string_view reason; std::string_view reason;
}; };
static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos static constexpr std::array<StatusCodeInfo, 13> statusCodeInfos
= {{{.code = StatusCode::OK, .reason = "OK"}, = {{{.code = StatusCode::OK, .reason = "OK"},
{.code = StatusCode::CREATED, .reason = "Created"}, {.code = StatusCode::CREATED, .reason = "Created"},
{.code = StatusCode::NO_CONTENT, .reason = "No Content"}, {.code = StatusCode::NO_CONTENT, .reason = "No Content"},
@ -61,6 +62,7 @@ static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos
{.code = StatusCode::FORBIDDEN, .reason = "Forbidden"}, {.code = StatusCode::FORBIDDEN, .reason = "Forbidden"},
{.code = StatusCode::NOT_FOUND, .reason = "Not Found"}, {.code = StatusCode::NOT_FOUND, .reason = "Not Found"},
{.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"}, {.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"},
{.code = StatusCode::PAYLOAD_TOO_LARGE, .reason = "Payload Too Large"},
{.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"}, {.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"},
{.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"}, {.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"},
{.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"}, {.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"},

View File

@ -1,8 +1,9 @@
#include <webserv/http/HttpRequest.hpp> #include "webserv/config/ServerConfig.hpp"
#include <webserv/config/ConfigManager.hpp> // for ConfigManager #include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/handler/URI.hpp> // for URI #include <webserv/handler/URI.hpp> // for URI
#include <webserv/http/HttpConstants.hpp> // for CRLF, DOUBLE_CRLF #include <webserv/http/HttpConstants.hpp> // for CRLF, DOUBLE_CRLF
#include <webserv/http/HttpRequest.hpp>
#include <webserv/log/Log.hpp> // for Log, LOCATION #include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/utils/utils.hpp> // for stoul #include <webserv/utils/utils.hpp> // for stoul
@ -32,12 +33,10 @@ void HttpRequest::setState(State state)
{ {
if (state == State::Complete) if (state == State::Complete)
{ {
// TODO: segfault if server does not exist ServerConfig *serverConfig = getServerConfig();
std::string hostHeader = getHeaders().getHost().value_or(""); if (serverConfig == nullptr)
ServerConfig *serverConfig = ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
if (hostHeader.empty() || serverConfig == nullptr)
{ {
Log::error("No matching server config found for host: " + hostHeader); Log::error("No matching server config found for host");
state_ = State::ParseError; state_ = State::ParseError;
return; return;
} }
@ -246,6 +245,16 @@ void HttpRequest::reset()
// contentLength_ = 0; // contentLength_ = 0;
} }
ServerConfig *HttpRequest::getServerConfig() const
{
std::string hostHeader = getHeaders().getHost().value_or("");
if (hostHeader.empty())
{
return nullptr;
}
return ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
}
const URI &HttpRequest::getUri() const noexcept const URI &HttpRequest::getUri() const noexcept
{ {
return *uri_; return *uri_;

View File

@ -56,6 +56,8 @@ class HttpRequest
void parseBuffer(); void parseBuffer();
void parseContentLength(); void parseContentLength();
ServerConfig * getServerConfig() const;
Client *client_; Client *client_;
State state_ = State::RequestLine; State state_ = State::RequestLine;

View File

@ -0,0 +1,67 @@
#include "webserv/config/AConfig.hpp"
#include <webserv/http/RequestValidator.hpp>
#include <cstddef>
#include <string>
#include <vector>
RequestValidator::RequestValidator(const AConfig *config, const HttpRequest *request) : config(config), request(request)
{
}
std::optional<RequestValidator::ValidationError> RequestValidator::validate() const
{
if (auto error = validateContentLength())
{
return error;
}
if (auto error = validateMethod())
{
return error;
}
return std::nullopt; // No validation errors
}
std::optional<RequestValidator::ValidationError> RequestValidator::validateContentLength() const
{
size_t bodySize = request->getBody().size();
size_t maxBodySize = config->get<size_t>("client_max_body_size").value_or(1024 * 1024);
if (bodySize > maxBodySize) // exceed server limit
{
return ValidationError{413, "Payload Too Large"};
}
// If Content-Length header is present, validate it
if (request->getHeaders().getContentLength())
{
size_t contentLength = *request->getHeaders().getContentLength();
if (contentLength != bodySize)
{
return ValidationError{400, "Bad Request: Content-Length does not match body size"};
}
}
return std::nullopt; // Content-Length is valid
}
std::optional<RequestValidator::ValidationError> RequestValidator::validateMethod() const
{
auto allowedMethodsOpt = config->get<std::vector<std::string>>("allowed_methods");
std::vector<std::string> allowedMethods;
if (allowedMethodsOpt.has_value())
{
allowedMethods = std::move(*allowedMethodsOpt);
}
else
{
allowedMethods = {"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"};
}
for (const std::string &method : allowedMethods)
{
if (method == request->getMethod())
{
return std::nullopt; // Method is allowed
}
}
return ValidationError{405, "Method Not Allowed"};
}

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "webserv/config/AConfig.hpp"
#include <webserv/config/ServerConfig.hpp> #include <webserv/config/ServerConfig.hpp>
#include <webserv/http/HttpRequest.hpp> #include <webserv/http/HttpRequest.hpp>
@ -14,20 +15,14 @@ class RequestValidator
int statusCode; int statusCode;
std::string message; std::string message;
}; };
RequestValidator(const AConfig *config, const HttpRequest *request);
RequestValidator() = delete;
RequestValidator(const ServerConfig &config, const HttpRequest &request);
RequestValidator(const RequestValidator &other) = delete;
RequestValidator(RequestValidator &&other) = delete;
RequestValidator &operator=(const RequestValidator &other) = delete;
RequestValidator &operator=(RequestValidator &&other) = delete;
~RequestValidator() = default;
[[nodiscard]] std::optional<ValidationError> validate() const; [[nodiscard]] std::optional<ValidationError> validate() const;
private: private:
ServerConfig &config; const AConfig *config;
HttpRequest &request; const HttpRequest *request;
};
[[nodiscard]] std::optional<ValidationError> validateContentLength() const;
[[nodiscard]] std::optional<ValidationError> validateMethod() const;
};

View File

@ -1,4 +1,4 @@
#include <webserv/router/Router.hpp> // for Router #include "webserv/handler/ErrorHandler.hpp"
#include <webserv/client/Client.hpp> // for Client #include <webserv/client/Client.hpp> // for Client
#include <webserv/config/AConfig.hpp> // for AConfig #include <webserv/config/AConfig.hpp> // for AConfig
@ -10,6 +10,7 @@
#include <webserv/handler/URI.hpp> // for URI #include <webserv/handler/URI.hpp> // for URI
#include <webserv/http/HttpRequest.hpp> // for HttpRequest #include <webserv/http/HttpRequest.hpp> // for HttpRequest
#include <webserv/log/Log.hpp> // for Log, LOCATION #include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/router/Router.hpp> // for Router
#include <exception> // for exception #include <exception> // for exception
#include <format> // for vector #include <format> // for vector
@ -26,16 +27,16 @@ Router::Router(Client *client) : client_(client)
Log::trace(LOCATION); Log::trace(LOCATION);
} }
bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept // bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept
{ // {
const ADirective *allowedMethods = config.getDirective("allowed_methods"); // const ADirective *allowedMethods = config.getDirective("allowed_methods");
if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value()) // if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value())
{ // {
return true; // return true;
} // }
auto methods = allowedMethods->getValue().get<std::vector<std::string>>(); // auto methods = allowedMethods->getValue().get<std::vector<std::string>>();
return std::ranges::find(methods, method) != methods.end(); // return std::ranges::find(methods, method) != methods.end();
} // }
std::unique_ptr<AHandler> Router::handleRequest() std::unique_ptr<AHandler> Router::handleRequest()
{ {
@ -50,17 +51,25 @@ std::unique_ptr<AHandler> Router::handleRequest()
} }
HttpResponse &response = client_->getHttpResponse(); HttpResponse &response = client_->getHttpResponse();
const std::string &target = request.getTarget(); // const std::string &target = request.getTarget();
static_cast<void>(target); // Suppress unused variable warning // static_cast<void>(target); // Suppress unused variable warning
const std::string &method = request.getMethod(); // const std::string &method = request.getMethod();
const AConfig *config = request.getUri().getConfig(); const AConfig *config = request.getUri().getConfig();
auto validator = std::make_unique<RequestValidator>(config, &request);
if (!isMethodSupported(method, *config)) auto error = validator->validate();
if (error.has_value())
{ {
Log::warning("Request validation failed: " + error->message);
ErrorHandler::createErrorResponse(error->statusCode, response, config);
return nullptr; return nullptr;
// return ErrorHandler::getErrorResponse(405, config);
} }
// if (!isMethodSupported(method, *config))
// {
// return nullptr;
// // return ErrorHandler::getErrorResponse(405, config);
// }
if (request.getUri().isCgi()) if (request.getUri().isCgi())
{ {
try try

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "webserv/http/RequestValidator.hpp"
#include <webserv/config/AConfig.hpp> // for AConfig #include <webserv/config/AConfig.hpp> // for AConfig
#include <webserv/config/LocationConfig.hpp> #include <webserv/config/LocationConfig.hpp>
#include <webserv/handler/AHandler.hpp> // for AHandler #include <webserv/handler/AHandler.hpp> // for AHandler
@ -7,7 +8,6 @@
#include <webserv/http/HttpResponse.hpp> // for HttpResponse #include <webserv/http/HttpResponse.hpp> // for HttpResponse
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <string> // for string
class LocationConfig; class LocationConfig;
class ServerConfig; class ServerConfig;
@ -21,5 +21,6 @@ class Router
private: private:
Client *client_; Client *client_;
[[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept; std::unique_ptr<RequestValidator> requestValidator_;
// [[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept;
}; };