feat: request validation
This commit is contained in:
parent
d98fb98c52
commit
4cf3eac4d1
@ -1,5 +1,4 @@
|
|||||||
#include <webserv/client/Client.hpp>
|
#include <webserv/client/Client.hpp>
|
||||||
|
|
||||||
#include <webserv/handler/CgiHandler.hpp> // for CgiHandler
|
#include <webserv/handler/CgiHandler.hpp> // for CgiHandler
|
||||||
#include <webserv/handler/ErrorHandler.hpp> // for ErrorHandler
|
#include <webserv/handler/ErrorHandler.hpp> // for ErrorHandler
|
||||||
#include <webserv/http/HttpHeaders.hpp> // for HttpHeaders
|
#include <webserv/http/HttpHeaders.hpp> // for HttpHeaders
|
||||||
@ -91,15 +90,23 @@ void Client::request()
|
|||||||
{"body", httpRequest_->getBody()},
|
{"body", httpRequest_->getBody()},
|
||||||
{"state", std::to_string(static_cast<uint8_t>(httpRequest_->getState()))},
|
{"state", std::to_string(static_cast<uint8_t>(httpRequest_->getState()))},
|
||||||
});
|
});
|
||||||
// server_.responseReady(client_socket_->getFd());
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
// Thoughts: if a handler isn't returned, this could because of the error handler already setting up the response
|
||||||
|
// so, maybe we don't need to throw a 500 when no handler. Because that would override the actual error response.
|
||||||
|
// How about the router, or a handler, throws an exception if something goes wrong, and we catch it here to make a 500 response?
|
||||||
handler_ = router_->handleRequest();
|
handler_ = router_->handleRequest();
|
||||||
if (handler_ != nullptr)
|
if (handler_ != nullptr)
|
||||||
{
|
{
|
||||||
handler_->handle();
|
handler_->handle();
|
||||||
}
|
}
|
||||||
else
|
}
|
||||||
|
catch (const std::exception &e)
|
||||||
{
|
{
|
||||||
|
Log::error("Exception during request handling: " + std::string(e.what()));
|
||||||
ErrorHandler::createErrorResponse(500, *httpResponse_);
|
ErrorHandler::createErrorResponse(500, *httpResponse_);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@ -112,23 +119,6 @@ void Client::request()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// //
|
|
||||||
// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
|
|
||||||
// {
|
|
||||||
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
|
|
||||||
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
|
|
||||||
|
|
||||||
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
|
|
||||||
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
|
|
||||||
// {
|
|
||||||
// server_.remove(*cgiSocket); // write
|
|
||||||
|
|
||||||
// sockets_.erase(cgiSocket->getFd());
|
|
||||||
// }
|
|
||||||
|
|
||||||
void Client::addSocket(ASocket *socket)
|
void Client::addSocket(ASocket *socket)
|
||||||
{
|
{
|
||||||
server_.add(*socket, this);
|
server_.add(*socket, this);
|
||||||
|
|||||||
@ -40,6 +40,7 @@ constexpr uint16_t UNAUTHORIZED = 401;
|
|||||||
constexpr uint16_t FORBIDDEN = 403;
|
constexpr uint16_t FORBIDDEN = 403;
|
||||||
constexpr uint16_t NOT_FOUND = 404;
|
constexpr uint16_t NOT_FOUND = 404;
|
||||||
constexpr uint16_t METHOD_NOT_ALLOWED = 405;
|
constexpr uint16_t METHOD_NOT_ALLOWED = 405;
|
||||||
|
constexpr uint16_t PAYLOAD_TOO_LARGE = 413;
|
||||||
constexpr uint16_t INTERNAL_SERVER_ERROR = 500;
|
constexpr uint16_t INTERNAL_SERVER_ERROR = 500;
|
||||||
constexpr uint16_t NOT_IMPLEMENTED = 501;
|
constexpr uint16_t NOT_IMPLEMENTED = 501;
|
||||||
constexpr uint16_t BAD_GATEWAY = 502;
|
constexpr uint16_t BAD_GATEWAY = 502;
|
||||||
@ -52,7 +53,7 @@ struct StatusCodeInfo
|
|||||||
std::string_view reason;
|
std::string_view reason;
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos
|
static constexpr std::array<StatusCodeInfo, 13> statusCodeInfos
|
||||||
= {{{.code = StatusCode::OK, .reason = "OK"},
|
= {{{.code = StatusCode::OK, .reason = "OK"},
|
||||||
{.code = StatusCode::CREATED, .reason = "Created"},
|
{.code = StatusCode::CREATED, .reason = "Created"},
|
||||||
{.code = StatusCode::NO_CONTENT, .reason = "No Content"},
|
{.code = StatusCode::NO_CONTENT, .reason = "No Content"},
|
||||||
@ -61,6 +62,7 @@ static constexpr std::array<StatusCodeInfo, 12> statusCodeInfos
|
|||||||
{.code = StatusCode::FORBIDDEN, .reason = "Forbidden"},
|
{.code = StatusCode::FORBIDDEN, .reason = "Forbidden"},
|
||||||
{.code = StatusCode::NOT_FOUND, .reason = "Not Found"},
|
{.code = StatusCode::NOT_FOUND, .reason = "Not Found"},
|
||||||
{.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"},
|
{.code = StatusCode::METHOD_NOT_ALLOWED, .reason = "Method Not Allowed"},
|
||||||
|
{.code = StatusCode::PAYLOAD_TOO_LARGE, .reason = "Payload Too Large"},
|
||||||
{.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"},
|
{.code = StatusCode::INTERNAL_SERVER_ERROR, .reason = "Internal Server Error"},
|
||||||
{.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"},
|
{.code = StatusCode::NOT_IMPLEMENTED, .reason = "Not Implemented"},
|
||||||
{.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"},
|
{.code = StatusCode::BAD_GATEWAY, .reason = "Bad Gateway"},
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
#include <webserv/http/HttpRequest.hpp>
|
#include "webserv/config/ServerConfig.hpp"
|
||||||
|
|
||||||
#include <webserv/config/ConfigManager.hpp> // for ConfigManager
|
#include <webserv/config/ConfigManager.hpp> // for ConfigManager
|
||||||
#include <webserv/handler/URI.hpp> // for URI
|
#include <webserv/handler/URI.hpp> // for URI
|
||||||
#include <webserv/http/HttpConstants.hpp> // for CRLF, DOUBLE_CRLF
|
#include <webserv/http/HttpConstants.hpp> // for CRLF, DOUBLE_CRLF
|
||||||
|
#include <webserv/http/HttpRequest.hpp>
|
||||||
#include <webserv/log/Log.hpp> // for Log, LOCATION
|
#include <webserv/log/Log.hpp> // for Log, LOCATION
|
||||||
#include <webserv/utils/utils.hpp> // for stoul
|
#include <webserv/utils/utils.hpp> // for stoul
|
||||||
|
|
||||||
@ -32,12 +33,10 @@ void HttpRequest::setState(State state)
|
|||||||
{
|
{
|
||||||
if (state == State::Complete)
|
if (state == State::Complete)
|
||||||
{
|
{
|
||||||
// TODO: segfault if server does not exist
|
ServerConfig *serverConfig = getServerConfig();
|
||||||
std::string hostHeader = getHeaders().getHost().value_or("");
|
if (serverConfig == nullptr)
|
||||||
ServerConfig *serverConfig = ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
|
|
||||||
if (hostHeader.empty() || serverConfig == nullptr)
|
|
||||||
{
|
{
|
||||||
Log::error("No matching server config found for host: " + hostHeader);
|
Log::error("No matching server config found for host");
|
||||||
state_ = State::ParseError;
|
state_ = State::ParseError;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -246,6 +245,16 @@ void HttpRequest::reset()
|
|||||||
// contentLength_ = 0;
|
// contentLength_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ServerConfig *HttpRequest::getServerConfig() const
|
||||||
|
{
|
||||||
|
std::string hostHeader = getHeaders().getHost().value_or("");
|
||||||
|
if (hostHeader.empty())
|
||||||
|
{
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return ConfigManager::getInstance().getMatchingServerConfig(hostHeader);
|
||||||
|
}
|
||||||
|
|
||||||
const URI &HttpRequest::getUri() const noexcept
|
const URI &HttpRequest::getUri() const noexcept
|
||||||
{
|
{
|
||||||
return *uri_;
|
return *uri_;
|
||||||
|
|||||||
@ -56,6 +56,8 @@ class HttpRequest
|
|||||||
void parseBuffer();
|
void parseBuffer();
|
||||||
void parseContentLength();
|
void parseContentLength();
|
||||||
|
|
||||||
|
ServerConfig * getServerConfig() const;
|
||||||
|
|
||||||
Client *client_;
|
Client *client_;
|
||||||
|
|
||||||
State state_ = State::RequestLine;
|
State state_ = State::RequestLine;
|
||||||
|
|||||||
67
webserv/http/RequestValidator.cpp
Normal file
67
webserv/http/RequestValidator.cpp
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#include "webserv/config/AConfig.hpp"
|
||||||
|
|
||||||
|
#include <webserv/http/RequestValidator.hpp>
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
RequestValidator::RequestValidator(const AConfig *config, const HttpRequest *request) : config(config), request(request)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<RequestValidator::ValidationError> RequestValidator::validate() const
|
||||||
|
{
|
||||||
|
if (auto error = validateContentLength())
|
||||||
|
{
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
if (auto error = validateMethod())
|
||||||
|
{
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
return std::nullopt; // No validation errors
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<RequestValidator::ValidationError> RequestValidator::validateContentLength() const
|
||||||
|
{
|
||||||
|
size_t bodySize = request->getBody().size();
|
||||||
|
size_t maxBodySize = config->get<size_t>("client_max_body_size").value_or(1024 * 1024);
|
||||||
|
if (bodySize > maxBodySize) // exceed server limit
|
||||||
|
{
|
||||||
|
return ValidationError{413, "Payload Too Large"};
|
||||||
|
}
|
||||||
|
// If Content-Length header is present, validate it
|
||||||
|
if (request->getHeaders().getContentLength())
|
||||||
|
{
|
||||||
|
size_t contentLength = *request->getHeaders().getContentLength();
|
||||||
|
if (contentLength != bodySize)
|
||||||
|
{
|
||||||
|
return ValidationError{400, "Bad Request: Content-Length does not match body size"};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::nullopt; // Content-Length is valid
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<RequestValidator::ValidationError> RequestValidator::validateMethod() const
|
||||||
|
{
|
||||||
|
auto allowedMethodsOpt = config->get<std::vector<std::string>>("allowed_methods");
|
||||||
|
std::vector<std::string> allowedMethods;
|
||||||
|
|
||||||
|
if (allowedMethodsOpt.has_value())
|
||||||
|
{
|
||||||
|
allowedMethods = std::move(*allowedMethodsOpt);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
allowedMethods = {"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"};
|
||||||
|
}
|
||||||
|
for (const std::string &method : allowedMethods)
|
||||||
|
{
|
||||||
|
if (method == request->getMethod())
|
||||||
|
{
|
||||||
|
return std::nullopt; // Method is allowed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ValidationError{405, "Method Not Allowed"};
|
||||||
|
}
|
||||||
@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "webserv/config/AConfig.hpp"
|
||||||
#include <webserv/config/ServerConfig.hpp>
|
#include <webserv/config/ServerConfig.hpp>
|
||||||
#include <webserv/http/HttpRequest.hpp>
|
#include <webserv/http/HttpRequest.hpp>
|
||||||
|
|
||||||
@ -14,20 +15,14 @@ class RequestValidator
|
|||||||
int statusCode;
|
int statusCode;
|
||||||
std::string message;
|
std::string message;
|
||||||
};
|
};
|
||||||
|
RequestValidator(const AConfig *config, const HttpRequest *request);
|
||||||
RequestValidator() = delete;
|
|
||||||
RequestValidator(const ServerConfig &config, const HttpRequest &request);
|
|
||||||
RequestValidator(const RequestValidator &other) = delete;
|
|
||||||
RequestValidator(RequestValidator &&other) = delete;
|
|
||||||
|
|
||||||
RequestValidator &operator=(const RequestValidator &other) = delete;
|
|
||||||
RequestValidator &operator=(RequestValidator &&other) = delete;
|
|
||||||
|
|
||||||
~RequestValidator() = default;
|
|
||||||
|
|
||||||
[[nodiscard]] std::optional<ValidationError> validate() const;
|
[[nodiscard]] std::optional<ValidationError> validate() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ServerConfig &config;
|
const AConfig *config;
|
||||||
HttpRequest &request;
|
const HttpRequest *request;
|
||||||
};
|
|
||||||
|
[[nodiscard]] std::optional<ValidationError> validateContentLength() const;
|
||||||
|
[[nodiscard]] std::optional<ValidationError> validateMethod() const;
|
||||||
|
};
|
||||||
@ -1,4 +1,4 @@
|
|||||||
#include <webserv/router/Router.hpp> // for Router
|
#include "webserv/handler/ErrorHandler.hpp"
|
||||||
|
|
||||||
#include <webserv/client/Client.hpp> // for Client
|
#include <webserv/client/Client.hpp> // for Client
|
||||||
#include <webserv/config/AConfig.hpp> // for AConfig
|
#include <webserv/config/AConfig.hpp> // for AConfig
|
||||||
@ -10,6 +10,7 @@
|
|||||||
#include <webserv/handler/URI.hpp> // for URI
|
#include <webserv/handler/URI.hpp> // for URI
|
||||||
#include <webserv/http/HttpRequest.hpp> // for HttpRequest
|
#include <webserv/http/HttpRequest.hpp> // for HttpRequest
|
||||||
#include <webserv/log/Log.hpp> // for Log, LOCATION
|
#include <webserv/log/Log.hpp> // for Log, LOCATION
|
||||||
|
#include <webserv/router/Router.hpp> // for Router
|
||||||
|
|
||||||
#include <exception> // for exception
|
#include <exception> // for exception
|
||||||
#include <format> // for vector
|
#include <format> // for vector
|
||||||
@ -26,16 +27,16 @@ Router::Router(Client *client) : client_(client)
|
|||||||
Log::trace(LOCATION);
|
Log::trace(LOCATION);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept
|
// bool Router::isMethodSupported(const std::string &method, const AConfig &config) noexcept
|
||||||
{
|
// {
|
||||||
const ADirective *allowedMethods = config.getDirective("allowed_methods");
|
// const ADirective *allowedMethods = config.getDirective("allowed_methods");
|
||||||
if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value())
|
// if (allowedMethods == nullptr || !allowedMethods->getValue().try_get<std::vector<std::string>>().has_value())
|
||||||
{
|
// {
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
auto methods = allowedMethods->getValue().get<std::vector<std::string>>();
|
// auto methods = allowedMethods->getValue().get<std::vector<std::string>>();
|
||||||
return std::ranges::find(methods, method) != methods.end();
|
// return std::ranges::find(methods, method) != methods.end();
|
||||||
}
|
// }
|
||||||
|
|
||||||
std::unique_ptr<AHandler> Router::handleRequest()
|
std::unique_ptr<AHandler> Router::handleRequest()
|
||||||
{
|
{
|
||||||
@ -50,17 +51,25 @@ std::unique_ptr<AHandler> Router::handleRequest()
|
|||||||
}
|
}
|
||||||
HttpResponse &response = client_->getHttpResponse();
|
HttpResponse &response = client_->getHttpResponse();
|
||||||
|
|
||||||
const std::string &target = request.getTarget();
|
// const std::string &target = request.getTarget();
|
||||||
static_cast<void>(target); // Suppress unused variable warning
|
// static_cast<void>(target); // Suppress unused variable warning
|
||||||
const std::string &method = request.getMethod();
|
// const std::string &method = request.getMethod();
|
||||||
|
|
||||||
const AConfig *config = request.getUri().getConfig();
|
const AConfig *config = request.getUri().getConfig();
|
||||||
|
auto validator = std::make_unique<RequestValidator>(config, &request);
|
||||||
if (!isMethodSupported(method, *config))
|
auto error = validator->validate();
|
||||||
|
if (error.has_value())
|
||||||
{
|
{
|
||||||
|
Log::warning("Request validation failed: " + error->message);
|
||||||
|
ErrorHandler::createErrorResponse(error->statusCode, response, config);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
// return ErrorHandler::getErrorResponse(405, config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if (!isMethodSupported(method, *config))
|
||||||
|
// {
|
||||||
|
// return nullptr;
|
||||||
|
// // return ErrorHandler::getErrorResponse(405, config);
|
||||||
|
// }
|
||||||
if (request.getUri().isCgi())
|
if (request.getUri().isCgi())
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "webserv/http/RequestValidator.hpp"
|
||||||
#include <webserv/config/AConfig.hpp> // for AConfig
|
#include <webserv/config/AConfig.hpp> // for AConfig
|
||||||
#include <webserv/config/LocationConfig.hpp>
|
#include <webserv/config/LocationConfig.hpp>
|
||||||
#include <webserv/handler/AHandler.hpp> // for AHandler
|
#include <webserv/handler/AHandler.hpp> // for AHandler
|
||||||
@ -7,7 +8,6 @@
|
|||||||
#include <webserv/http/HttpResponse.hpp> // for HttpResponse
|
#include <webserv/http/HttpResponse.hpp> // for HttpResponse
|
||||||
|
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
#include <string> // for string
|
|
||||||
|
|
||||||
class LocationConfig;
|
class LocationConfig;
|
||||||
class ServerConfig;
|
class ServerConfig;
|
||||||
@ -21,5 +21,6 @@ class Router
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Client *client_;
|
Client *client_;
|
||||||
[[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept;
|
std::unique_ptr<RequestValidator> requestValidator_;
|
||||||
|
// [[nodiscard]] bool isMethodSupported(const std::string &method, const AConfig &config) noexcept;
|
||||||
};
|
};
|
||||||
Loading…
Reference in New Issue
Block a user