feat: implement RequiredDirectivesRule for config validation

This commit is contained in:
whaffman 2025-10-08 17:23:12 +02:00
parent a0b85774b1
commit fb84cd09fa
8 changed files with 132 additions and 27 deletions

View File

@ -13,12 +13,6 @@ class DirectiveFactory
public:
static std::unique_ptr<ADirective> createDirective(const std::string &line);
private:
using CreatorFunc = std::function<std::unique_ptr<ADirective>(const std::string &, const std::string &arg)>;
static const std::unordered_map<std::string_view, CreatorFunc> &getFactories();
static std::unique_ptr<ADirective> create(std::string_view type, const std::string &name, const std::string &arg);
struct DirectiveInfo
{
std::string_view name;
@ -26,21 +20,31 @@ class DirectiveFactory
std::string_view context;
};
constexpr static std::array<DirectiveInfo, 15> supportedDirectives = {{
{.name = "listen", .type = "IntDirective", .context = "SL"},
{.name = "host", .type = "StringDirective", .context = "SL"},
{.name = "server_name", .type = "StringDirective", .context = "SL"},
{.name = "root", .type = "StringDirective", .context = "SL"},
{.name = "index", .type = "VectorDirective", .context = "SL"},
{.name = "error_page", .type = "IntStringDirective", .context = "SL"},
{.name = "client_max_body_size", .type = "SizeDirective", .context = "SL"},
{.name = "autoindex", .type = "BoolDirective", .context = "L"},
{.name = "allowed_methods", .type = "VectorDirective", .context = "L"},
{.name = "cgi_enabled", .type = "BoolDirective", .context = "L"},
{.name = "cgi_ext", .type = "VectorDirective", .context = "L"},
{.name = "cgi_timeout", .type = "IntDirective", .context = "L"},
{.name = "upload_enabled", .type = "BoolDirective", .context = "L"},
{.name = "upload_store", .type = "StringDirective", .context = "L"},
{.name = "redirect", .type = "VectorDirective", .context = "L"},
{.name = "listen", .type = "IntDirective", .context = "S"},
{.name = "host", .type = "StringDirective", .context = "S"},
{.name = "server_name", .type = "StringDirective", .context = "S"},
{.name = "root", .type = "StringDirective", .context = "Sl"},
{.name = "index", .type = "VectorDirective", .context = "sl"},
{.name = "error_page", .type = "IntStringDirective", .context = "gsl"},
{.name = "client_max_body_size", .type = "SizeDirective", .context = "gsl"},
{.name = "autoindex", .type = "BoolDirective", .context = "gsl"},
{.name = "allowed_methods", .type = "VectorDirective", .context = "gsl"},
{.name = "cgi_enabled", .type = "BoolDirective", .context = "gsl"},
{.name = "cgi_ext", .type = "VectorDirective", .context = "gsl"},
{.name = "cgi_timeout", .type = "IntDirective", .context = "gsl"},
{.name = "upload_enabled", .type = "BoolDirective", .context = "gsl"},
{.name = "upload_store", .type = "StringDirective", .context = "gsl"},
{.name = "redirect", .type = "VectorDirective", .context = "l"},
}};
private:
using CreatorFunc = std::function<std::unique_ptr<ADirective>(const std::string &, const std::string &arg)>;
static const std::unordered_map<std::string_view, CreatorFunc> &getFactories();
static std::unique_ptr<ADirective> create(std::string_view type, const std::string &name, const std::string &arg);
};

View File

@ -21,6 +21,7 @@ ConfigValidator::ConfigValidator(const GlobalConfig *config) : engine_(std::make
engine_->addStructuralRule(std::make_unique<MinimumServerBlocksRule>(1));
engine_->addStructuralRule(std::make_unique<RequiredLocationBlocksRule>(1));
engine_->addStructuralRule(std::make_unique<UniqueServerNamesRule>());
engine_->addStructuralRule(std::make_unique<RequiredDirectivesRule>());
/*Global Directive Rules*/

View File

@ -1,7 +1,6 @@
#include <webserv/config/validation/directive_rules/AValidationRule.hpp>
#include <webserv/config/AConfig.hpp> // for AConfig
#include <webserv/config/validation/ValidationResult.hpp> // for ValidationResult
#include <webserv/config/validation/directive_rules/AValidationRule.hpp>
#include <webserv/log/Log.hpp> // for LOCATION, Log
#include <utility> // for move
@ -19,7 +18,7 @@ ValidationResult AValidationRule::validate(const AConfig *config, const std::str
return ValidationResult::error("Invalid config or directive name");
}
if (!config->hasDirective(directiveName))
if (!config->has(directiveName))
{
if (requiresValue_)
{

View File

@ -2,4 +2,3 @@
#include <webserv/config/validation/directive_rules/AllowedValuesRule.hpp>
#include <webserv/config/validation/directive_rules/PortValidationRule.hpp>
#include <webserv/config/validation/directive_rules/RequiredDirectiveRule.hpp>

View File

@ -0,0 +1,72 @@
#include <webserv/config/GlobalConfig.hpp>
#include <webserv/config/LocationConfig.hpp>
#include <webserv/config/ServerConfig.hpp>
#include <webserv/config/directive/DirectiveFactory.hpp>
#include <webserv/config/validation/structural_rules/AStructuralValidationRule.hpp>
#include <webserv/config/validation/structural_rules/RequiredDirectivesRule.hpp>
#include <webserv/log/Log.hpp>
#include <webserv/utils/utils.hpp>
#include <algorithm>
#include <cctype>
RequiredDirectivesRule::RequiredDirectivesRule()
: AStructuralValidationRule("RequiredDirectivesRule", "Ensures required directives are present in each context")
{
}
ValidationResult validateUniversal(const AConfig *config, std::string configType)
{
std::vector<std::string> missingDirectives;
std::vector<std::string> prohibitedDirectives;
for (const auto &directive : DirectiveFactory::supportedDirectives)
{
if (directive.context.find(std::toupper(configType[0])) != std::string::npos &&
!config->owns(std::string(directive.name)))
{
missingDirectives.emplace_back(directive.name);
}
if ((directive.context.find(std::toupper(configType[0])) == std::string::npos &&
directive.context.find(std::tolower(configType[0])) == std::string::npos) &&
config->owns(std::string(directive.name)))
{
prohibitedDirectives.emplace_back(directive.name);
}
}
if (missingDirectives.empty() && prohibitedDirectives.empty())
{
return ValidationResult::success();
}
std::string result;
if (!missingDirectives.empty())
{
result += "Missing " + configType + " directive: ";
result += utils::implode(missingDirectives, ", ");
}
if (!prohibitedDirectives.empty())
{
result += "Prohibited " + configType + " directive: ";
result += utils::implode(prohibitedDirectives, ", ");
}
return ValidationResult::error(result);
}
ValidationResult RequiredDirectivesRule::validateGlobal(const GlobalConfig *config) const
{
return validateUniversal(config, "global");
}
ValidationResult RequiredDirectivesRule::validateServer(const ServerConfig *config) const
{
return validateUniversal(config, "server");
}
ValidationResult RequiredDirectivesRule::validateLocation(const LocationConfig *config) const
{
return validateUniversal(config, "location");
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <webserv/config/validation/structural_rules/AStructuralValidationRule.hpp>
#include <cstddef>
class GlobalConfig;
class ServerConfig;
class LocationConfig;
class RequiredDirectivesRule : public AStructuralValidationRule
{
public:
explicit RequiredDirectivesRule();
~RequiredDirectivesRule() override = default;
RequiredDirectivesRule(const RequiredDirectivesRule &other) = delete;
RequiredDirectivesRule &operator=(const RequiredDirectivesRule &other) = delete;
RequiredDirectivesRule(RequiredDirectivesRule &&other) noexcept = delete;
RequiredDirectivesRule &operator=(RequiredDirectivesRule &&other) noexcept = delete;
[[nodiscard]] ValidationResult validateGlobal(const GlobalConfig *config) const override;
[[nodiscard]] ValidationResult validateServer(const ServerConfig *config) const override;
[[nodiscard]] ValidationResult validateLocation(const LocationConfig *config) const override;
};

View File

@ -7,3 +7,4 @@
#include <webserv/config/validation/structural_rules/MinimumServerBlocksRule.hpp>
#include <webserv/config/validation/structural_rules/RequiredLocationBlocksRule.hpp>
#include <webserv/config/validation/structural_rules/UniqueServerNamesRule.hpp>
#include <webserv/config/validation/structural_rules/RequiredDirectivesRule.hpp>

View File

@ -8,6 +8,9 @@
#include <string> // for allocator, basic_string, char_traits, operator+, string
#include <vector> // for vector
int main(int argc, char **argv)
{
if (argc < 2)
@ -16,8 +19,8 @@ int main(int argc, char **argv)
return 1;
}
Log::setFileChannel("webserv.log", std::ios_base::trunc, Log::Level::Info);
Log::setStdoutChannel(Log::Level::Info);
Log::setFileChannel("webserv.log", std::ios_base::trunc);
Log::setStdoutChannel();
Log::info("\n======================\nStarting webserv...\n======================\n");
ConfigManager &configManager = ConfigManager::getInstance();