// query/test/orm/sql/StatementCacheTest.cpp
//
// 330 lines
// 10 KiB
// C++

#include <atomic>
#include <chrono>
#include <mutex>
#include <optional>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <catch2/catch_test_macros.hpp>

#include <matador/query/query.hpp>
#include "matador/sql/connection_pool.hpp"
#include "matador/sql/error_code.hpp"
#include "matador/sql/statement_cache.hpp"
#include "matador/utils/message_bus.hpp"
#include "../backend/test_backend_service.hpp"
#include "ConnectionPoolFixture.hpp"
using namespace matador::test;
using namespace matador::sql;
using namespace matador::query;
using namespace matador::utils;
/// Aggregates timing metrics published on a message_bus.
///
/// Subscribes to statement_lock_failed_event and statement_execution_event and
/// accumulates, under mutex_, the number of events seen and the sum of their
/// `duration` fields. Every getter takes the same mutex, so the getters may be
/// called while events are still being delivered.
class MetricsObserver {
public:
  /// Subscribes to lock-failure and execution events on @p bus. The returned
  /// subscription handles are kept for the lifetime of the observer.
  explicit MetricsObserver(message_bus &bus) {
    subscriptions_.push_back(bus.subscribe<statement_lock_failed_event>([this](const statement_lock_failed_event &ev) {
      std::lock_guard lock(mutex_);
      ++lock_failure_count_;
      total_lock_wait_time_ += ev.duration;
    }));
    subscriptions_.push_back(bus.subscribe<statement_execution_event>([this](const statement_execution_event &ev) {
      std::lock_guard lock(mutex_);
      ++execution_count_;
      total_execution_time_ += ev.duration;
    }));
  }

  /// @return mean recorded lock wait per lock failure, truncated to whole
  ///         milliseconds; 0 ms when no lock failure has been recorded.
  [[nodiscard]] std::chrono::milliseconds get_average_lock_wait_time() const {
    std::lock_guard lock(mutex_);
    return average_millis(total_lock_wait_time_, lock_failure_count_);
  }

  /// @return mean recorded execution time per execution event, truncated to
  ///         whole milliseconds; 0 ms when no execution has been recorded.
  [[nodiscard]] std::chrono::milliseconds get_average_execution_time() const {
    std::lock_guard lock(mutex_);
    return average_millis(total_execution_time_, execution_count_);
  }

  /// @return sum of all recorded lock-wait durations, truncated to milliseconds.
  [[nodiscard]] std::chrono::milliseconds get_total_lock_wait_time() const {
    std::lock_guard lock(mutex_);
    return std::chrono::duration_cast<std::chrono::milliseconds>(total_lock_wait_time_);
  }

  /// @return sum of all recorded execution durations, truncated to milliseconds.
  [[nodiscard]] std::chrono::milliseconds get_total_execution_time() const {
    std::lock_guard lock(mutex_);
    return std::chrono::duration_cast<std::chrono::milliseconds>(total_execution_time_);
  }

  /// @return number of statement_lock_failed_event messages received so far.
  [[nodiscard]] size_t get_lock_failure_count() const {
    std::lock_guard lock(mutex_);
    return lock_failure_count_;
  }

private:
  /// Averages @p total over @p count and truncates to milliseconds; returns
  /// 0 ms for a zero count. The division happens on the duration with the
  /// signed nanosecond rep, avoiding a signed/unsigned mix with size_t.
  /// Caller must hold mutex_ when passing member state.
  static std::chrono::milliseconds average_millis(const std::chrono::nanoseconds total, const size_t count) {
    if (count == 0) {
      return std::chrono::milliseconds{0};
    }
    const auto average = total / static_cast<std::chrono::nanoseconds::rep>(count);
    return std::chrono::duration_cast<std::chrono::milliseconds>(average);
  }

  mutable std::mutex mutex_;
  size_t lock_failure_count_{0};
  size_t execution_count_{0};
  std::chrono::nanoseconds total_lock_wait_time_{0};
  std::chrono::nanoseconds total_execution_time_{0};
  // Declared last so it is destroyed first (members are destroyed in reverse
  // declaration order): the callbacks above touch mutex_ and the counters, so
  // the subscriptions must go away before that state does.
  // NOTE(review): relies on `subscription` unsubscribing on destruction — confirm.
  std::vector<subscription> subscriptions_;
};
class RecordingObserver final {
public:
explicit RecordingObserver(message_bus &bus) {
subscriptions.push_back(bus.subscribe<statement_accessed_event>([this](const statement_accessed_event &ev) {
std::lock_guard lock(mutex);
events.push(message::from_ref(ev));
}));
subscriptions.push_back(bus.subscribe<statement_added_event>([this](const statement_added_event &ev) {
std::lock_guard lock(mutex);
events.push(message::from_ref(ev));
}));
subscriptions.push_back(bus.subscribe<statement_evicted_event>([this](const statement_evicted_event &ev) {
std::lock_guard lock(mutex);
events.push(message::from_ref(ev));
}));
}
std::optional<message> poll() {
std::lock_guard lock(mutex);
if (events.empty()) {
return std::nullopt;
}
auto evt = events.front();
events.pop();
return evt;
}
private:
std::vector<subscription> subscriptions;
std::mutex mutex;
std::queue<message> events;
};
// Basic size/capacity bookkeeping of a statement_cache with capacity 2,
// including the bound on size once a third distinct statement is acquired.
TEST_CASE("Test statement cache", "[statement][cache]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());

  message_bus bus;
  connection_pool pool("noop://noop.db", 4);
  statement_cache cache(bus, pool, 2);

  query_context ctx;
  ctx.sql = "SELECT * FROM person";

  // A fresh cache is empty and reports the configured capacity.
  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  // First statement: one entry.
  auto result = cache.acquire(ctx);
  REQUIRE(result);
  REQUIRE(cache.size() == 1);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);
  // Held across the later acquisitions and checked at the end.
  auto stmt = result.value();

  // Second distinct statement fills the cache to capacity.
  ctx.sql = "SELECT title FROM book";
  result = cache.acquire(ctx);
  REQUIRE(result);
  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  // Third distinct statement exceeds the capacity; the size stays bounded.
  ctx.sql = "SELECT name FROM author";
  result = cache.acquire(ctx);
  REQUIRE(result);
  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  // The statement handed out first still reports its SQL after the cache overflowed.
  REQUIRE(stmt.sql() == "SELECT * FROM person");
}
// Acquires four distinct statements through a cache of capacity 2 and checks
// that the size stays bounded and that add/evict events are published.
TEST_CASE("Test LRU cache evicts oldest entries", "[statement][cache][evict]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());
  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 2);
  RecordingObserver observer(bus);

  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  auto result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt1 = result.value();

  result = cache.acquire({"SELECT title FROM book"});
  REQUIRE(result);
  auto stmt2 = result.value();

  // Third distinct statement exceeds the capacity of 2; expected to evict
  // "SELECT * FROM person", the least recently used entry.
  result = cache.acquire({"SELECT name FROM author"});
  REQUIRE(result);
  auto stmt3 = result.value();

  // Fourth distinct statement (a new one, not a re-acquire of an evicted
  // entry); expected to evict "SELECT title FROM book", the next LRU entry.
  result = cache.acquire({"SELECT 1"});
  REQUIRE(result);
  auto stmt4 = result.value();

  // Handles obtained earlier still report their SQL after their entries were evicted.
  REQUIRE(stmt1.sql() == "SELECT * FROM person");
  REQUIRE(stmt2.sql() == "SELECT title FROM book");
  REQUIRE(stmt3.sql() == "SELECT name FROM author");
  REQUIRE(stmt4.sql() == "SELECT 1");

  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  // The event checks are deliberately loose lower bounds.
  int added = 0;
  int evicted = 0;
  while (auto e = observer.poll()) {
    if (e->is<statement_added_event>()) {
      ++added;
    }
    if (e->is<statement_evicted_event>()) {
      ++evicted;
    }
  }
  REQUIRE(added >= 3);
  REQUIRE(evicted >= 1);
}
// Acquiring the same SQL twice must be served from the cache: one entry, and
// no second statement_added_event for the repeated SQL.
TEST_CASE("Test statement reuse avoids reprepare", "[statement][cache][prepare]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());
  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 2);
  RecordingObserver observer(bus);

  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  auto result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt1 = result.value();

  result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt2 = result.value();

  REQUIRE(stmt1.sql() == "SELECT * FROM person");
  REQUIRE(stmt2.sql() == "SELECT * FROM person");
  // The repeated SQL must not create a second cache entry.
  REQUIRE(cache.size() == 1);

  // Assumes the cache publishes statement_added_event whenever it prepares a
  // new entry (as the eviction test relies on); a reprepare would add a second one.
  int added = 0;
  while (auto e = observer.poll()) {
    if (e->is<statement_added_event>()) {
      ++added;
    }
  }
  REQUIRE(added <= 1);
}
// TEST_CASE("Multithreaded stress test", "[statement][cache][stress]") {
// backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());
//
// constexpr int thread_count = 16;
// constexpr int iterations = 1000;
// constexpr int sql_pool_size = 10;
//
// std::vector<std::string> sqls;
// for (int i = 0; i < sql_pool_size; ++i) {
// sqls.push_back("SELECT " + std::to_string(i));
// }
//
// connection_pool pool("noop://noop.db", 4);
// message_bus bus;
// statement_cache cache(bus, pool, 5);
// RecordingObserver observer(bus);
// MetricsObserver metrics(bus);
//
// auto start_time = std::chrono::steady_clock::now();
//
// std::atomic_int lock_failed_count{0};
// std::atomic_int exec_failed_count{0};
//
// auto worker = [&](const int tid) {
// std::mt19937 rng(tid);
// std::uniform_int_distribution dist(0, sql_pool_size - 1);
//
// for (int i = 0; i < iterations; ++i) {
// const auto& sql = sqls[dist(rng)];
// if (const auto result = cache.acquire({sql}); !result) {
// FAIL("Failed to acquire statement");
// } else {
// if (const auto exec_result = result->execute(); !exec_result) {
// if (exec_result.err().ec() == error_code::STATEMENT_LOCKED) {
// ++lock_failed_count;
// } else {
// ++exec_failed_count;
// }
// }
// }
// }
// };
//
// std::vector<std::thread> threads;
// for (int i = 0; i < thread_count; ++i) {
// threads.emplace_back(worker, i);
// }
//
// for (auto& t : threads) {
// t.join();
// }
//
// auto end_time = std::chrono::steady_clock::now();
// auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
//
// std::cout << "[Performance] Executed " << (thread_count * iterations) << " statements in " << duration.count() << " ms (lock failed: " << lock_failed_count << ", execute failed: " << exec_failed_count << ")\n";
//
// std::cout << "Average lock wait time: " << metrics.get_average_lock_wait_time().count() << "ms\n";
// std::cout << "Total lock wait time: " << metrics.get_total_lock_wait_time().count() << "ms\n";
// std::cout << "Average execution time: " << metrics.get_average_execution_time().count() << "ms\n";
// std::cout << "Total execution time: " << metrics.get_total_execution_time().count() << "ms\n";
// std::cout << "Number of lock failures: " << metrics.get_lock_failure_count() << "\n";
//
// // Some events should be generated
// int accessed = 0;
// while (auto e = observer.poll()) {
// if (e->is<statement_accessed_event>()) accessed++;
// }
// REQUIRE(accessed > 0);
// }
// Hammers a shared statement_cache (capacity 5) from several threads with ten
// distinct SQL strings, so entries are continuously hit and evicted concurrently.
TEST_CASE("Race condition simulation with mixed access", "[statement_cache][race]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());
  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 5);

  constexpr int threads = 8;
  constexpr int operations = 500;

  // Catch2 assertion macros are not thread-safe by default, and FAIL throws:
  // an exception escaping a std::thread entry point calls std::terminate.
  // Workers therefore only count failures; the main thread asserts afterwards.
  std::atomic_int failed_acquires{0};

  auto task = [&](int /*id*/) {
    for (int i = 0; i < operations; ++i) {
      const auto sql = "SELECT " + std::to_string(i % 10);
      if (const auto result = cache.acquire({sql}); !result) {
        ++failed_acquires;
      }
      // if (i % 50 == 0) {
      //   cache.cleanup_expired_connections();
      // }
    }
  };

  std::vector<std::thread> jobs;
  jobs.reserve(threads);
  for (int i = 0; i < threads; ++i) {
    jobs.emplace_back(task, i);
  }
  for (auto &t : jobs) {
    t.join();
  }

  // Every acquire should have produced a statement.
  REQUIRE(failed_acquires.load() == 0);
  SUCCEED("Race simulation completed successfully without crash");
}