330 lines
10 KiB
C++
330 lines
10 KiB
C++
#include <atomic>
#include <chrono>
#include <mutex>
#include <optional>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <catch2/catch_test_macros.hpp>

#include <matador/query/query.hpp>

#include "matador/sql/connection_pool.hpp"
#include "matador/sql/error_code.hpp"
#include "matador/sql/statement_cache.hpp"

#include "matador/utils/message_bus.hpp"

#include "../backend/test_backend_service.hpp"

#include "ConnectionPoolFixture.hpp"

using namespace matador::test;
using namespace matador::sql;
using namespace matador::query;
using namespace matador::utils;
|
|
|
|
class MetricsObserver {
|
|
public:
|
|
explicit MetricsObserver(message_bus &bus) {
|
|
subscriptions.push_back(bus.subscribe<statement_lock_failed_event>([this](const statement_lock_failed_event &ev) {
|
|
std::lock_guard lock(mutex_);
|
|
lock_failure_count_++;
|
|
total_lock_wait_time_ += ev.duration;
|
|
}));
|
|
subscriptions.push_back(bus.subscribe<statement_execution_event>([this](const statement_execution_event &ev) {
|
|
std::lock_guard lock(mutex_);
|
|
execution_count_++;
|
|
total_execution_time_ += ev.duration;
|
|
}));
|
|
|
|
}
|
|
|
|
std::chrono::milliseconds get_average_lock_wait_time() const {
|
|
std::lock_guard lock(mutex_);
|
|
if (lock_failure_count_ == 0) {
|
|
return std::chrono::milliseconds{0};
|
|
}
|
|
const auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(total_lock_wait_time_);
|
|
return std::chrono::milliseconds(millis.count() / lock_failure_count_);
|
|
}
|
|
|
|
std::chrono::milliseconds get_average_execution_time() const {
|
|
std::lock_guard lock(mutex_);
|
|
if (execution_count_ == 0) {
|
|
return std::chrono::milliseconds{0};
|
|
}
|
|
|
|
const auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(total_execution_time_);
|
|
return std::chrono::milliseconds(millis.count() / execution_count_);
|
|
}
|
|
|
|
std::chrono::milliseconds get_total_lock_wait_time() const {
|
|
std::lock_guard lock(mutex_);
|
|
return std::chrono::duration_cast<std::chrono::milliseconds>(total_lock_wait_time_);
|
|
}
|
|
|
|
std::chrono::milliseconds get_total_execution_time() const {
|
|
std::lock_guard lock(mutex_);
|
|
return std::chrono::duration_cast<std::chrono::milliseconds>(total_execution_time_);
|
|
}
|
|
|
|
size_t get_lock_failure_count() const {
|
|
std::lock_guard lock(mutex_);
|
|
return lock_failure_count_;
|
|
}
|
|
|
|
private:
|
|
std::vector<subscription> subscriptions;
|
|
mutable std::mutex mutex_;
|
|
size_t lock_failure_count_{0};
|
|
size_t execution_count_{0};
|
|
std::chrono::nanoseconds total_lock_wait_time_{0};
|
|
std::chrono::nanoseconds total_execution_time_{0};
|
|
};
|
|
|
|
class RecordingObserver final {
|
|
public:
|
|
explicit RecordingObserver(message_bus &bus) {
|
|
subscriptions.push_back(bus.subscribe<statement_accessed_event>([this](const statement_accessed_event &ev) {
|
|
std::lock_guard lock(mutex);
|
|
events.push(message::from_ref(ev));
|
|
}));
|
|
subscriptions.push_back(bus.subscribe<statement_added_event>([this](const statement_added_event &ev) {
|
|
std::lock_guard lock(mutex);
|
|
events.push(message::from_ref(ev));
|
|
}));
|
|
subscriptions.push_back(bus.subscribe<statement_evicted_event>([this](const statement_evicted_event &ev) {
|
|
std::lock_guard lock(mutex);
|
|
events.push(message::from_ref(ev));
|
|
}));
|
|
}
|
|
|
|
std::optional<message> poll() {
|
|
std::lock_guard lock(mutex);
|
|
if (events.empty()) {
|
|
return std::nullopt;
|
|
}
|
|
auto evt = events.front();
|
|
events.pop();
|
|
return evt;
|
|
}
|
|
|
|
private:
|
|
std::vector<subscription> subscriptions;
|
|
std::mutex mutex;
|
|
std::queue<message> events;
|
|
};
|
|
|
|
// Verifies size/empty/capacity bookkeeping while filling a capacity-2 cache
// with three distinct statements.
TEST_CASE("Test statement cache", "[statement][cache]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());

  matador::utils::message_bus bus;
  connection_pool pool("noop://noop.db", 4);
  statement_cache cache(bus, pool, 2);

  query_context ctx;
  ctx.sql = "SELECT * FROM person";

  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  auto result = cache.acquire(ctx);
  REQUIRE(result);

  REQUIRE(cache.size() == 1);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  // The acquired statement must carry the SQL it was prepared from.
  auto stmt = result.value();
  REQUIRE(stmt.sql() == "SELECT * FROM person");

  ctx.sql = "SELECT title FROM book";
  result = cache.acquire(ctx);
  REQUIRE(result);

  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  // A third distinct statement must not grow the cache beyond its capacity.
  ctx.sql = "SELECT name FROM author";
  result = cache.acquire(ctx);
  REQUIRE(result);

  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);
}
|
|
|
|
// Fills a capacity-2 cache, forces an eviction, then re-acquires the evicted
// statement so it must be prepared again.
TEST_CASE("Test LRU cache evicts oldest entries", "[statement][cache][evict]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());

  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 2);
  RecordingObserver observer(bus);

  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  auto result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt1 = result.value();
  result = cache.acquire({"SELECT title FROM book"});
  REQUIRE(result);
  auto stmt2 = result.value();
  result = cache.acquire({"SELECT name FROM author"}); // Should evict the first statement
  REQUIRE(result);
  auto stmt3 = result.value();

  // Trigger a re-prepare of the statement evicted above. With capacity 2 this
  // is a cache miss that adds it again and evicts the next-oldest entry.
  result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt4 = result.value();

  REQUIRE(stmt1.sql() == "SELECT * FROM person");
  REQUIRE(stmt2.sql() == "SELECT title FROM book");
  REQUIRE(stmt3.sql() == "SELECT name FROM author");
  REQUIRE(stmt4.sql() == "SELECT * FROM person");

  REQUIRE(cache.size() == 2);
  REQUIRE(!cache.empty());
  REQUIRE(cache.capacity() == 2);

  int added = 0;
  int evicted = 0;
  while (auto e = observer.poll()) {
    if (e->is<statement_added_event>()) {
      ++added;
    } else if (e->is<statement_evicted_event>()) {
      ++evicted;
    }
  }
  // Four inserts (including the re-prepare) into a capacity-2 cache: at least
  // four additions and two evictions.
  REQUIRE(added >= 4);
  REQUIRE(evicted >= 2);
}
|
|
|
|
// Acquiring the same SQL twice must be served from the cache: one entry, and
// only one statement_added_event on the bus.
TEST_CASE("Test statement reuse avoids reprepare", "[statement][cache][prepare]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());

  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 2);
  RecordingObserver observer(bus);

  REQUIRE(cache.capacity() == 2);
  REQUIRE(cache.empty());

  auto result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt1 = result.value();
  result = cache.acquire({"SELECT * FROM person"});
  REQUIRE(result);
  auto stmt2 = result.value();

  REQUIRE(stmt1.sql() == "SELECT * FROM person");
  REQUIRE(stmt2.sql() == stmt1.sql());
  REQUIRE(cache.size() == 1);

  int added = 0;
  while (auto e = observer.poll()) {
    if (e->is<statement_added_event>()) {
      ++added;
    }
  }
  // A second prepare would have published a second "added" event.
  REQUIRE(added == 1);
}
|
|
|
|
// TEST_CASE("Multithreaded stress test", "[statement][cache][stress]") {
|
|
// backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());
|
|
//
|
|
// constexpr int thread_count = 16;
|
|
// constexpr int iterations = 1000;
|
|
// constexpr int sql_pool_size = 10;
|
|
//
|
|
// std::vector<std::string> sqls;
|
|
// for (int i = 0; i < sql_pool_size; ++i) {
|
|
// sqls.push_back("SELECT " + std::to_string(i));
|
|
// }
|
|
//
|
|
// connection_pool pool("noop://noop.db", 4);
|
|
// message_bus bus;
|
|
// statement_cache cache(bus, pool, 5);
|
|
// RecordingObserver observer(bus);
|
|
// MetricsObserver metrics(bus);
|
|
//
|
|
// auto start_time = std::chrono::steady_clock::now();
|
|
//
|
|
// std::atomic_int lock_failed_count{0};
|
|
// std::atomic_int exec_failed_count{0};
|
|
//
|
|
// auto worker = [&](const int tid) {
|
|
// std::mt19937 rng(tid);
|
|
// std::uniform_int_distribution dist(0, sql_pool_size - 1);
|
|
//
|
|
// for (int i = 0; i < iterations; ++i) {
|
|
// const auto& sql = sqls[dist(rng)];
|
|
// if (const auto result = cache.acquire({sql}); !result) {
|
|
// FAIL("Failed to acquire statement");
|
|
// } else {
|
|
// if (const auto exec_result = result->execute(); !exec_result) {
|
|
// if (exec_result.err().ec() == error_code::STATEMENT_LOCKED) {
|
|
// ++lock_failed_count;
|
|
// } else {
|
|
// ++exec_failed_count;
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
// };
|
|
//
|
|
// std::vector<std::thread> threads;
|
|
// for (int i = 0; i < thread_count; ++i) {
|
|
// threads.emplace_back(worker, i);
|
|
// }
|
|
//
|
|
// for (auto& t : threads) {
|
|
// t.join();
|
|
// }
|
|
//
|
|
// auto end_time = std::chrono::steady_clock::now();
|
|
// auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
|
|
//
|
|
// std::cout << "[Performance] Executed " << (thread_count * iterations) << " statements in " << duration.count() << " ms (lock failed: " << lock_failed_count << ", execute failed: " << exec_failed_count << ")\n";
|
|
//
|
|
// std::cout << "Average lock wait time: " << metrics.get_average_lock_wait_time().count() << "ms\n";
|
|
// std::cout << "Total lock wait time: " << metrics.get_total_lock_wait_time().count() << "ms\n";
|
|
// std::cout << "Average execution time: " << metrics.get_average_execution_time().count() << "ms\n";
|
|
// std::cout << "Total execution time: " << metrics.get_total_execution_time().count() << "ms\n";
|
|
// std::cout << "Number of lock failures: " << metrics.get_lock_failure_count() << "\n";
|
|
//
|
|
// // Some events should be generated
|
|
// int accessed = 0;
|
|
// while (auto e = observer.poll()) {
|
|
// if (e->is<statement_accessed_event>()) accessed++;
|
|
// }
|
|
// REQUIRE(accessed > 0);
|
|
// }
|
|
|
|
// Hammers one shared cache from several threads with a rotating set of SQL
// strings and checks that every acquire succeeds and nothing crashes.
TEST_CASE("Race condition simulation with mixed access", "[statement_cache][race]") {
  backend_provider::instance().register_backend("noop", std::make_unique<orm::test_backend_service>());

  connection_pool pool("noop://noop.db", 4);
  message_bus bus;
  statement_cache cache(bus, pool, 5);

  constexpr int threads = 8;
  constexpr int operations = 500;

  // Catch2 assertion macros must not be used from worker threads: FAIL/REQUIRE
  // abort by throwing, and an exception escaping a std::thread calls
  // std::terminate. Workers only count failures; the main thread asserts.
  std::atomic_int acquire_failed_count{0};

  auto task = [&](int /*id*/) {
    for (int i = 0; i < operations; ++i) {
      const auto sql = "SELECT " + std::to_string(i % 10);
      if (const auto result = cache.acquire({sql}); !result) {
        ++acquire_failed_count;
      }

      // if (i % 50 == 0) {
      //   cache.cleanup_expired_connections();
      // }
    }
  };

  std::vector<std::thread> jobs;
  jobs.reserve(threads);
  for (int i = 0; i < threads; ++i) {
    jobs.emplace_back(task, i);
  }

  for (auto &t : jobs) {
    t.join();
  }

  REQUIRE(acquire_failed_count.load() == 0);
  SUCCEED("Race simulation completed successfully without crash");
}