feat: request validation

This commit is contained in:
Quinten 2025-10-26 19:18:11 +01:00
parent d98fb98c52
commit 4cf3eac4d1
8 changed files with 137 additions and 62 deletions

View File

@ -1,5 +1,4 @@
#include <webserv/client/Client.hpp>
#include <webserv/handler/CgiHandler.hpp> // for CgiHandler
#include <webserv/handler/ErrorHandler.hpp> // for ErrorHandler
#include <webserv/http/HttpHeaders.hpp> // for HttpHeaders
@ -91,15 +90,23 @@ void Client::request()
{"body", httpRequest_->getBody()},
{"state", std::to_string(static_cast<uint8_t>(httpRequest_->getState()))},
});
// server_.responseReady(client_socket_->getFd());
handler_ = router_->handleRequest();
if (handler_ != nullptr)
try
{
handler_->handle();
// Thoughts: if a handler isn't returned, this could because of the error handler already setting up the response
// so, maybe we don't need to throw a 500 when no handler. Because that would override the actual error response.
// How about the router, or a handler, throws an exception if something goes wrong, and we catch it here to make a 500 response?
handler_ = router_->handleRequest();
if (handler_ != nullptr)
{
handler_->handle();
}
}
else
catch (const std::exception &e)
{
Log::error("Exception during request handling: " + std::string(e.what()));
ErrorHandler::createErrorResponse(500, *httpResponse_);
return;
}
}
else
@ -112,23 +119,6 @@ void Client::request()
}
}
// //
// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
// {
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
// }
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
// {
// server_.remove(*cgiSocket); // write
// sockets_.erase(cgiSocket->getFd());
// }
void Client::addSocket(ASocket *socket)
{
server_.add(*socket, this);

View File

@ -40,6 +40,7 @@ constexpr uint16_t UNAUTHORIZED = 401;
constexpr uint16_t FORBIDDEN = 403;
constexpr uint16_t NOT_FOUND = 404;
constexpr uint16_t METHOD_NOT_ALLOWED = 405;
constexpr uint16_t PAYLOAD_TOO_LARGE = 413;
constexpr uint16_t INTERNAL_SERVER_ERROR = 500;
constexpr uint16_t NOT_IMPLEMENTED = 501;
constexpr uint16_t BAD_GATEWAY = 502;
@ -52,7 +53,7 @@ struct StatusCodeInfo
std::string_view reason;
};
static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos
static constexpr std::array<StatusCodeInfo, 13> statusCodeInfos
= {{{.code = StatusCode::OK, .reason = "OK"},
{.code = StatusCode::CREATED, .reason = "Created"},
{.code = StatusCode::NO_CONTENT, .reason = "No Content"},
@ -61,6 +62,7 @@ static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos
{.code = StatusCode::FORBIDDEN, .reason = "Forbidden"},
{.code = StatusCode::NOT_FOUND, .reason = "Not Found"},
{.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"},
{.code = StatusCode::PAYLOAD_TOO_LARGE, .reason = "Payload Too Large"},
{.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"},
{.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"},
{.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"},

View File

@ -1,8 +1,9 @@
#include <webserv/http/HttpRequest.hpp>
#include "webserv/config/ServerConfig.hpp"
#include <webserv/config/ConfigManager.hpp> // for ConfigManager
#include <webserv/handler/URI.hpp> // for URI
#include <webserv/http/HttpConstants.hpp> // for CRLF, DOUBLE_CRLF
#include <webserv/http/HttpRequest.hpp>
#include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/utils/utils.hpp> // for stoul
@ -32,12 +33,10 @@ void HttpRequest::setState(State state)
{
if (state == State::Complete)
{
// TODO: segfault if server does not exist
std::string hostHeader = getHeaders().getHost().value_or("");
ServerConfig *serverConfig = ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
if (hostHeader.empty() || serverConfig == nullptr)
ServerConfig *serverConfig = getServerConfig();
if (serverConfig == nullptr)
{
Log::error("No matching server config found for host: " + hostHeader);
Log::error("No matching server config found for host");
state_ = State::ParseError;
return;
}
@ -246,6 +245,16 @@ void HttpRequest::reset()
// contentLength_ = 0;
}
ServerConfig *HttpRequest::getServerConfig() const
{
std::string hostHeader = getHeaders().getHost().value_or("");
if (hostHeader.empty())
{
return nullptr;
}
return ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
}
const URI &HttpRequest::getUri() const noexcept
{
return *uri_;

View File

@ -56,6 +56,8 @@ class HttpRequest
void parseBuffer();
void parseContentLength();
ServerConfig * getServerConfig() const;
Client *client_;
State state_ = State::RequestLine;

View File

@ -0,0 +1,67 @@
#include "webserv/config/AConfig.hpp"
#include <webserv/http/RequestValidator.hpp>
#include <cstddef>
#include <string>
#include <vector>
RequestValidator::RequestValidator(const AConfig *config, const HttpRequest *request) : config(config), request(request)
{
}
std::optional<RequestValidator::ValidationError> RequestValidator::validate() const
{
if (auto error = validateContentLength())
{
return error;
}
if (auto error = validateMethod())
{
return error;
}
return std::nullopt; // No validation errors
}
std::optional<RequestValidator::ValidationError> RequestValidator::validateContentLength() const
{
size_t bodySize = request->getBody().size();
size_t maxBodySize = config->get<size_t>("client_max_body_size").value_or(1024 * 1024);
if (bodySize > maxBodySize) // exceed server limit
{
return ValidationError{413, "Payload Too Large"};
}
// If Content-Length header is present, validate it
if (request->getHeaders().getContentLength())
{
size_t contentLength = *request->getHeaders().getContentLength();
if (contentLength != bodySize)
{
return ValidationError{400, "Bad Request: Content-Length does not match body size"};
}
}
return std::nullopt; // Content-Length is valid
}
std::optional<RequestValidator::ValidationError> RequestValidator::validateMethod() const
{
auto allowedMethodsOpt = config->get<std::vector<std::string>>("allowed_methods");
std::vector<std::string> allowedMethods;
if (allowedMethodsOpt.has_value())
{
allowedMethods = std::move(*allowedMethodsOpt);
}
else
{
allowedMethods = {"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"};
}
for (const std::string &method : allowedMethods)
{
if (method == request->getMethod())
{
return std::nullopt; // Method is allowed
}
}
return ValidationError{405, "Method Not Allowed"};
}

View File

@ -1,5 +1,6 @@
#pragma once
#include "webserv/config/AConfig.hpp"
#include <webserv/config/ServerConfig.hpp>
#include <webserv/http/HttpRequest.hpp>
@ -14,20 +15,14 @@ class RequestValidator
int statusCode;
std::string message;
};
RequestValidator() = delete;
RequestValidator(const ServerConfig &config, const HttpRequest &request);
RequestValidator(const RequestValidator &other) = delete;
RequestValidator(RequestValidator &&other) = delete;
RequestValidator &operator=(const RequestValidator &other) = delete;
RequestValidator &operator=(RequestValidator &&other) = delete;
~RequestValidator() = default;
RequestValidator(const AConfig *config, const HttpRequest *request);
[[nodiscard]] std::optional<ValidationError> validate() const;
private:
ServerConfig &config;
HttpRequest &request;
};
const AConfig *config;
const HttpRequest *request;
[[nodiscard]] std::optional<ValidationError> validateContentLength() const;
[[nodiscard]] std::optional<ValidationError> validateMethod() const;
};

View File

@ -1,4 +1,4 @@
#include <webserv/router/Router.hpp> // for Router
#include "webserv/handler/ErrorHandler.hpp"
#include <webserv/client/Client.hpp> // for Client
#include <webserv/config/AConfig.hpp> // for AConfig
@ -10,6 +10,7 @@
#include <webserv/handler/URI.hpp> // for URI
#include <webserv/http/HttpRequest.hpp> // for HttpRequest
#include <webserv/log/Log.hpp> // for Log, LOCATION
#include <webserv/router/Router.hpp> // for Router
#include <exception> // for exception
#include <format> // for vector
@ -26,16 +27,16 @@ Router::Router(Client *client) : client_(client)
Log::trace(LOCATION);
}
bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept
{
const ADirective *allowedMethods = config.getDirective("allowed_methods");
if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value())
{
return true;
}
auto methods = allowedMethods->getValue().get<std::vector<std::string>>();
return std::ranges::find(methods, method) != methods.end();
}
// bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept
// {
// const ADirective *allowedMethods = config.getDirective("allowed_methods");
// if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value())
// {
// return true;
// }
// auto methods = allowedMethods->getValue().get<std::vector<std::string>>();
// return std::ranges::find(methods, method) != methods.end();
// }
std::unique_ptr<AHandler> Router::handleRequest()
{
@ -50,17 +51,25 @@ std::unique_ptr<AHandler> Router::handleRequest()
}
HttpResponse &response = client_->getHttpResponse();
const std::string &target = request.getTarget();
static_cast<void>(target); // Suppress unused variable warning
const std::string &method = request.getMethod();
// const std::string &target = request.getTarget();
// static_cast<void>(target); // Suppress unused variable warning
// const std::string &method = request.getMethod();
const AConfig *config = request.getUri().getConfig();
if (!isMethodSupported(method, *config))
auto validator = std::make_unique<RequestValidator>(config, &request);
auto error = validator->validate();
if (error.has_value())
{
Log::warning("Request validation failed: " + error->message);
ErrorHandler::createErrorResponse(error->statusCode, response, config);
return nullptr;
// return ErrorHandler::getErrorResponse(405, config);
}
// if (!isMethodSupported(method, *config))
// {
// return nullptr;
// // return ErrorHandler::getErrorResponse(405, config);
// }
if (request.getUri().isCgi())
{
try

View File

@ -1,5 +1,6 @@
#pragma once
#include "webserv/http/RequestValidator.hpp"
#include <webserv/config/AConfig.hpp> // for AConfig
#include <webserv/config/LocationConfig.hpp>
#include <webserv/handler/AHandler.hpp> // for AHandler
@ -7,7 +8,6 @@
#include <webserv/http/HttpResponse.hpp> // for HttpResponse
#include <memory> // for unique_ptr
#include <string> // for string
class LocationConfig;
class ServerConfig;
@ -21,5 +21,6 @@ class Router
private:
Client *client_;
[[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept;
std::unique_ptr<RequestValidator> requestValidator_;
// [[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept;
};