refactor(Client, Server, ASocket): update socket handling and state management
This commit is contained in:
parent
eedb5382ea
commit
cfb88745bb
@ -98,23 +98,33 @@ void Client::request()
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
|
||||
{
|
||||
server_.add(*cgiStdIn, EPOLLOUT, this); // write
|
||||
server_.add(*cgiStdOut, EPOLLIN, this); // read
|
||||
// //
|
||||
// void Client::setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut)
|
||||
// {
|
||||
// // server_.add(*cgiStdIn, EPOLLOUT, this); // write
|
||||
// // server_.add(*cgiStdOut, EPOLLIN, this); // read
|
||||
|
||||
sockets_[cgiStdIn->getFd()] = cgiStdIn;
|
||||
sockets_[cgiStdOut->getFd()] = cgiStdOut;
|
||||
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
|
||||
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
|
||||
// }
|
||||
|
||||
// void Client::removeCgiSocket(CgiSocket *cgiSocket)
|
||||
// {
|
||||
// server_.remove(*cgiSocket); // write
|
||||
|
||||
// sockets_.erase(cgiSocket->getFd());
|
||||
// }
|
||||
|
||||
void Client::addSocket(ASocket *socket)
|
||||
{
|
||||
server_.add(*socket, this);
|
||||
sockets_[socket->getFd()] = socket;
|
||||
}
|
||||
|
||||
void Client::removeCgiSocket(CgiSocket *cgiSocket)
|
||||
void Client::removeSocket(ASocket *socket)
|
||||
{
|
||||
server_.remove(*cgiSocket); // write
|
||||
|
||||
sockets_.erase(cgiSocket->getFd());
|
||||
// sockets_[cgiStdIn->getFd()] = cgiStdIn;
|
||||
// sockets_[cgiStdOut->getFd()] = cgiStdOut;
|
||||
server_.remove(*socket);
|
||||
sockets_.erase(socket->getFd());
|
||||
}
|
||||
|
||||
void Client::poll() const
|
||||
@ -130,7 +140,8 @@ void Client::poll() const
|
||||
{
|
||||
Log::info("Response is ready to be sent to client, fd: " + std::to_string(clientSocket_->getFd()));
|
||||
clientSocket_->setCallback([this]() { respond(); });
|
||||
server_.writable(clientSocket_->getFd());
|
||||
// server_.writable(clientSocket_->getFd());
|
||||
clientSocket_->setIOState(ASocket::IOState::WRITE);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -43,8 +43,11 @@ class Client
|
||||
[[nodiscard]] ASocket &getSocket(int fd = -1) const;
|
||||
|
||||
// void setStatusCode(int code);
|
||||
void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut);
|
||||
void removeCgiSocket(CgiSocket *cgiSocket);
|
||||
// void setCgiSockets(CgiSocket *cgiStdIn, CgiSocket *cgiStdOut);
|
||||
// void removeCgiSocket(CgiSocket *cgiSocket);
|
||||
|
||||
void addSocket(ASocket *socket);
|
||||
void removeSocket(ASocket *socket);
|
||||
|
||||
[[nodiscard]] HttpRequest &getHttpRequest() const noexcept;
|
||||
[[nodiscard]] HttpResponse &getHttpResponse() const noexcept;
|
||||
|
||||
@ -34,7 +34,7 @@ void CgiHandler::write()
|
||||
if (request_.getBody().empty())
|
||||
{
|
||||
Log::debug("No body to write to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
|
||||
request_.getClient().removeCgiSocket(cgiStdIn_.get());
|
||||
request_.getClient().removeSocket(cgiStdIn_.get());
|
||||
cgiStdIn_ = nullptr;
|
||||
return;
|
||||
}
|
||||
@ -48,7 +48,7 @@ void CgiHandler::write()
|
||||
Log::debug("Wrote " + std::to_string(bytesWritten)
|
||||
+ " bytes to CGI stdin, fd: " + std::to_string(cgiStdIn_->getFd()));
|
||||
}
|
||||
request_.getClient().removeCgiSocket(cgiStdIn_.get());
|
||||
request_.getClient().removeSocket(cgiStdIn_.get());
|
||||
cgiStdIn_ = nullptr;
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ void CgiHandler::read()
|
||||
else if (bytesRead == 0)
|
||||
{
|
||||
Log::info("CGI process closed stdout, fd: " + std::to_string(cgiStdOut_->getFd()));
|
||||
request_.getClient().removeCgiSocket(cgiStdOut_.get());
|
||||
request_.getClient().removeSocket(cgiStdOut_.get());
|
||||
cgiStdOut_ = nullptr;
|
||||
parseCgiOutput();
|
||||
return;
|
||||
@ -92,9 +92,8 @@ void CgiHandler::setCgiSockets(std::unique_ptr<CgiSocket> cgiStdIn, std::unique_
|
||||
cgiStdOut_ = std::move(cgiStdOut);
|
||||
cgiStdIn_ = std::move(cgiStdIn);
|
||||
|
||||
request_.getClient().setCgiSockets(cgiStdIn_.get(), cgiStdOut_.get()); // write
|
||||
|
||||
// TODO add to handler
|
||||
request_.getClient().addSocket(cgiStdIn_.get());
|
||||
request_.getClient().addSocket(cgiStdOut_.get());
|
||||
}
|
||||
|
||||
void CgiHandler::wait() noexcept
|
||||
|
||||
@ -74,8 +74,8 @@ void CgiProcess::spawn()
|
||||
else
|
||||
{
|
||||
// Parent process
|
||||
auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1]);
|
||||
auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0]);
|
||||
auto cgiStdIn = std::make_unique<CgiSocket>(pipeStdin[1], ASocket::IOState::WRITE);
|
||||
auto cgiStdOut = std::make_unique<CgiSocket>(pipeStdout[0], ASocket::IOState::READ);
|
||||
close(pipeStdin[0]);
|
||||
close(pipeStdout[1]);
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#include "webserv/utils/utils.hpp"
|
||||
|
||||
#include <webserv/client/Client.hpp> // for Client
|
||||
#include <webserv/config/ConfigManager.hpp> // for ConfigManager
|
||||
#include <webserv/config/ServerConfig.hpp> // for ServerConfig
|
||||
@ -59,7 +61,7 @@ Server::~Server()
|
||||
}
|
||||
}
|
||||
|
||||
void Server::add(const ASocket &socket, uint32_t events, Client *client)
|
||||
void Server::add(ASocket &socket, Client *client)
|
||||
{
|
||||
if (socket.getType() != ASocket::Type::SERVER_SOCKET && client == nullptr)
|
||||
{
|
||||
@ -69,7 +71,7 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
|
||||
Log::trace(LOCATION);
|
||||
int fd = socket.getFd();
|
||||
struct epoll_event event{};
|
||||
event.events = events;
|
||||
event.events = utils::stateToEpoll(socket.getEvent());
|
||||
event.data.fd = fd;
|
||||
if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == -1)
|
||||
{
|
||||
@ -78,9 +80,10 @@ void Server::add(const ASocket &socket, uint32_t events, Client *client)
|
||||
}
|
||||
Log::debug("Socket added to epoll, fd: " + std::to_string(fd));
|
||||
socketToClient_[fd] = client;
|
||||
sockets_.insert(&socket);
|
||||
}
|
||||
|
||||
void Server::remove(const ASocket &socket)
|
||||
void Server::remove(ASocket &socket)
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
int fd = socket.getFd();
|
||||
@ -90,6 +93,7 @@ void Server::remove(const ASocket &socket)
|
||||
throw std::runtime_error("epoll_ctl DEL failed");
|
||||
}
|
||||
socketToClient_.erase(fd);
|
||||
sockets_.erase(&socket);
|
||||
}
|
||||
|
||||
void Server::disconnect(const Client &client)
|
||||
@ -112,7 +116,7 @@ void Server::setupServerSocket(const ServerConfig &config)
|
||||
serverSocket->listen(SOMAXCONN);
|
||||
int server_fd = serverSocket->getFd();
|
||||
|
||||
add(*serverSocket, EPOLLIN);
|
||||
add(*serverSocket);
|
||||
|
||||
listeners_.push_back(std::move(serverSocket));
|
||||
listener_fds_.insert(server_fd);
|
||||
@ -131,7 +135,7 @@ void Server::handleConnection(struct epoll_event *event)
|
||||
std::unique_ptr<ClientSocket> clientSocket = listener.accept();
|
||||
|
||||
auto client = std::make_unique<Client>(std::move(clientSocket), *this);
|
||||
add(client->getSocket(), EPOLLIN, client.get());
|
||||
add(client->getSocket(), client.get());
|
||||
clients_.emplace_back(std::move(client));
|
||||
}
|
||||
|
||||
@ -182,6 +186,23 @@ void Server::writable(int client_fd) const
|
||||
}
|
||||
}
|
||||
|
||||
void Server::update(const ASocket &socket) const
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
|
||||
int socketFd = socket.getFd();
|
||||
uint32_t events = utils::stateToEpoll(socket.getEvent());
|
||||
Log::debug("Socket (" + std::to_string(socket.getFd()) + ") is being updated");
|
||||
struct epoll_event evt{};
|
||||
evt.events = events;
|
||||
evt.data.fd = socketFd;
|
||||
if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, socketFd, &evt) == -1)
|
||||
{
|
||||
Log::error("epoll_ctl MOD failed for fd: " + std::to_string(socketFd));
|
||||
throw std::runtime_error("epoll_ctl MOD failed");
|
||||
}
|
||||
}
|
||||
|
||||
void Server::handleResponse(struct epoll_event *event) const
|
||||
{
|
||||
int socket_fd = event->data.fd;
|
||||
@ -254,6 +275,18 @@ void Server::pollClients() const
|
||||
}
|
||||
}
|
||||
|
||||
void Server::pollSockets()
|
||||
{
|
||||
for (auto *socket : sockets_)
|
||||
{
|
||||
if (socket->isDirty())
|
||||
{
|
||||
update(*socket);
|
||||
socket->processed();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Server::run()
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
@ -262,6 +295,7 @@ void Server::run()
|
||||
struct epoll_event events[MAX_EVENTS]; // NOLINT
|
||||
while (true)
|
||||
{
|
||||
pollSockets();
|
||||
pollClients();
|
||||
handleEpoll(events, MAX_EVENTS);
|
||||
}
|
||||
|
||||
@ -33,8 +33,9 @@ class Server
|
||||
~Server();
|
||||
|
||||
void run();
|
||||
void add(const ASocket &socket, uint32_t events, Client *client = nullptr);
|
||||
void remove(const ASocket &socket);
|
||||
void add(ASocket &socket, Client *client = nullptr);
|
||||
void remove(ASocket &socket);
|
||||
void update(const ASocket &socket) const;
|
||||
void disconnect(const Client &client);
|
||||
void writable(int client_fd) const;
|
||||
|
||||
@ -48,8 +49,10 @@ class Server
|
||||
std::set<int> listener_fds_;
|
||||
std::vector<std::unique_ptr<Client>> clients_;
|
||||
std::unordered_map<int, Client *> socketToClient_;
|
||||
std::set<ASocket *> sockets_;
|
||||
|
||||
void pollClients() const;
|
||||
void pollSockets();
|
||||
void handleEpoll(struct epoll_event *events, int max_events);
|
||||
|
||||
void handleEpollHangUp(struct epoll_event *event) const;
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#include <sys/socket.h> // for recv, send
|
||||
#include <unistd.h> // for close
|
||||
|
||||
ASocket::ASocket(int fd) : fd_(fd)
|
||||
ASocket::ASocket(int fd, IOState event) : fd_(fd), ioState_(event)
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
if (fd_ == -1)
|
||||
@ -65,6 +65,34 @@ int ASocket::getFd() const noexcept
|
||||
return fd_;
|
||||
}
|
||||
|
||||
ASocket::IOState ASocket::getEvent() const noexcept
|
||||
{
|
||||
return ioState_;
|
||||
}
|
||||
|
||||
bool ASocket::isDirty() const noexcept
|
||||
{
|
||||
return dirty_;
|
||||
}
|
||||
|
||||
void ASocket::setIOState(IOState event)
|
||||
{
|
||||
if (event == ioState_)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
Log::debug("Processing state change for socket " + std::to_string(fd_));
|
||||
dirty_ = true;
|
||||
ioState_ = event;
|
||||
}
|
||||
|
||||
void ASocket::processed()
|
||||
{
|
||||
Log::debug("Socket " + std::to_string(fd_) + " processed");
|
||||
dirty_ = false;
|
||||
}
|
||||
|
||||
void ASocket::setFd(int fd)
|
||||
{
|
||||
fd_ = fd;
|
||||
|
||||
@ -16,8 +16,15 @@ class ASocket
|
||||
CGI_SOCKET
|
||||
};
|
||||
|
||||
enum class IOState : uint32_t
|
||||
{
|
||||
NONE = 0,
|
||||
READ = 1 << 0,
|
||||
WRITE = 1 << 1,
|
||||
};
|
||||
|
||||
ASocket() = delete;
|
||||
explicit ASocket(int fd);
|
||||
explicit ASocket(int fd, IOState state = IOState::READ);
|
||||
|
||||
ASocket(const ASocket &other) = delete;
|
||||
ASocket &operator=(const ASocket &other) = delete;
|
||||
@ -28,6 +35,8 @@ class ASocket
|
||||
|
||||
[[nodiscard]] virtual Type getType() const noexcept = 0;
|
||||
[[nodiscard]] int getFd() const noexcept;
|
||||
[[nodiscard]] IOState getEvent() const noexcept;
|
||||
[[nodiscard]] bool isDirty() const noexcept;
|
||||
|
||||
void callback() const;
|
||||
void setCallback(std::function<void()> callback);
|
||||
@ -35,11 +44,16 @@ class ASocket
|
||||
virtual ssize_t read(void *buf, size_t len) const;
|
||||
virtual ssize_t write(const void *buf, size_t len) const;
|
||||
|
||||
void setIOState(IOState event);
|
||||
void processed();
|
||||
|
||||
protected:
|
||||
void setNonBlocking() const;
|
||||
void setFd(int fd);
|
||||
|
||||
private:
|
||||
int fd_;
|
||||
bool dirty_ = false;
|
||||
IOState ioState_;
|
||||
std::function<void()> callback_ = nullptr;
|
||||
};
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
#include <webserv/log/Log.hpp> // for LOCATION, Log
|
||||
#include <webserv/socket/ASocket.hpp> // for ASocket
|
||||
#include <webserv/socket/CgiSocket.hpp>
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
CgiSocket::CgiSocket(int fd) : ASocket(fd)
|
||||
CgiSocket::CgiSocket(int fd, ASocket::IOState event) : ASocket(fd, event)
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
}
|
||||
@ -13,7 +14,7 @@ ASocket::Type CgiSocket::getType() const noexcept
|
||||
return ASocket::Type::CGI_SOCKET;
|
||||
}
|
||||
|
||||
ssize_t CgiSocket::read(void *buf, size_t len) const
|
||||
ssize_t CgiSocket::read(void *buf, size_t len) const
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
ssize_t bytesRead = ::read(getFd(), buf, len);
|
||||
@ -24,7 +25,7 @@ ssize_t CgiSocket::read(void *buf, size_t len) const
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
ssize_t CgiSocket::write(const void *buf, size_t len) const
|
||||
ssize_t CgiSocket::write(const void *buf, size_t len) const
|
||||
{
|
||||
Log::trace(LOCATION);
|
||||
ssize_t bytesSent = ::write(getFd(), buf, len);
|
||||
|
||||
@ -9,10 +9,10 @@
|
||||
class CgiSocket : public ASocket
|
||||
{
|
||||
public:
|
||||
explicit CgiSocket(int fd);
|
||||
explicit CgiSocket(int fd, ASocket::IOState event);
|
||||
|
||||
[[nodiscard]] ASocket::Type getType() const noexcept override;
|
||||
|
||||
ssize_t read(void *buf, size_t len) const override;
|
||||
ssize_t write(const void *buf, size_t len) const override;
|
||||
ssize_t read(void *buf, size_t len) const override;
|
||||
ssize_t write(const void *buf, size_t len) const override;
|
||||
};
|
||||
@ -51,7 +51,7 @@ void ServerSocket::bind(const std::string &host, const int port) const
|
||||
struct sockaddr_in address{};
|
||||
address.sin_family = AF_INET;
|
||||
address.sin_addr.s_addr = inet_addr(host.c_str());
|
||||
address.sin_port = htons(port);
|
||||
address.sin_port = htons(static_cast<uint16_t>(port));
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
|
||||
if (::bind(getFd(), reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0)
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
|
||||
#include "webserv/socket/ASocket.hpp"
|
||||
|
||||
#include <webserv/utils/utils.hpp>
|
||||
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include <sys/epoll.h>
|
||||
|
||||
namespace utils
|
||||
{
|
||||
size_t stoul(const std::string &str, int base)
|
||||
@ -132,4 +138,19 @@ std::string implode(const std::vector<std::string> &elements, const std::string
|
||||
}
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
uint32_t stateToEpoll(const ASocket::IOState &event)
|
||||
{
|
||||
uint32_t epollEvents = 0;
|
||||
using EventType = std::underlying_type_t<ASocket::IOState>;
|
||||
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::READ)) != 0U)
|
||||
{
|
||||
epollEvents |= EPOLLIN;
|
||||
}
|
||||
if ((static_cast<EventType>(event) & static_cast<EventType>(ASocket::IOState::WRITE)) != 0U)
|
||||
{
|
||||
epollEvents |= EPOLLOUT;
|
||||
}
|
||||
return epollEvents;
|
||||
}
|
||||
} // namespace utils
|
||||
@ -1,13 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "webserv/socket/ASocket.hpp"
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
#include <string> // for string
|
||||
#include <cstdint>
|
||||
#include <string> // for string
|
||||
#include <vector>
|
||||
|
||||
namespace utils
|
||||
{
|
||||
size_t stoul(const std::string &str, int base = 10);
|
||||
// std::string trimSemi(const std::string &str);
|
||||
std::string trim(const std::string &str, const std::string &charset = " \t\n\r");
|
||||
size_t findCorrespondingClosingBrace(const std::string &str, size_t openPos);
|
||||
void removeEmptyLines(std::string &str);
|
||||
@ -15,4 +17,6 @@ void removeComments(std::string &str);
|
||||
|
||||
std::vector<std::string> split(const std::string &str, char delimiter);
|
||||
std::string implode(const std::vector<std::string> &elements, const std::string &delimiter);
|
||||
|
||||
uint32_t stateToEpoll(const ASocket::IOState &event);
|
||||
} // namespace utils
|
||||
Loading…
Reference in New Issue
Block a user