diff --git a/include/matador/sql/error_code.hpp b/include/matador/sql/error_code.hpp index cbab3c2..f493681 100644 --- a/include/matador/sql/error_code.hpp +++ b/include/matador/sql/error_code.hpp @@ -21,6 +21,7 @@ enum class error_code : uint8_t { RESET_FAILED, OPEN_ERROR, CLOSE_ERROR, + STATEMENT_LOCKED, FAILURE }; diff --git a/source/orm/sql/statement_cache.cpp b/source/orm/sql/statement_cache.cpp index 894d872..3490491 100644 --- a/source/orm/sql/statement_cache.cpp +++ b/source/orm/sql/statement_cache.cpp @@ -3,6 +3,8 @@ #include "matador/sql/error_code.hpp" #include "matador/sql/connection_pool.hpp" +#include + namespace matador::sql { namespace internal { class statement_cache_proxy final : public statement_proxy { @@ -11,11 +13,49 @@ public: : statement_proxy(std::move(stmt)) {} utils::result execute(interface::parameter_binder& bindings) override { + if (!try_lock()) { + return utils::failure(utils::error{ + error_code::STATEMENT_LOCKED, + "Failed to execute statement because it is already in use" + }); + } + + auto guard = statement_guard(*this); return statement_->execute(bindings); } utils::result, utils::error> fetch(interface::parameter_binder& bindings) override { + if (!try_lock()) { + return utils::failure(utils::error{ + error_code::STATEMENT_LOCKED, + "Failed to execute statement because it is already in use" + }); + } + + auto guard = statement_guard(*this); return statement_->fetch(bindings); } + +protected: + [[nodiscard]] bool try_lock() { + bool expected = false; + return locked_.compare_exchange_strong(expected, true); + } + + void unlock() { + locked_.store(false); + } + +private: + struct statement_guard { + explicit statement_guard(statement_cache_proxy &statement_proxy) + : proxy(statement_proxy) {} + ~statement_guard() { proxy.unlock(); } + + statement_cache_proxy &proxy; + }; + +private: + std::atomic_bool locked_{false}; }; } diff --git a/test/orm/backend/test_connection.cpp b/test/orm/backend/test_connection.cpp index aa3070d..556f367 100644 --- a/test/orm/backend/test_connection.cpp +++ b/test/orm/backend/test_connection.cpp @@ -3,7 +3,6 @@ #include "test_result_reader.hpp" #include "matador/sql/query_context.hpp" -#include "matador/sql/record.hpp" #include "matador/sql/internal/query_result_impl.hpp"