#include "postgres_connection.hpp"
#include "postgres_error.hpp"
#include "postgres_result_reader.hpp"
#include "postgres_statement.hpp"

#include "matador/object/attribute_definition.hpp"

#include "matador/sql/error_code.hpp"
#include "matador/sql/record.hpp"

#include "matador/sql/internal/query_result_impl.hpp"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace matador::backends::postgres {

namespace {

/// Decodes a libpq integer version number (PQlibVersion / PQserverVersion).
///
/// Since PostgreSQL 10 the number is major * 10000 + minor (10.1 -> 100001)
/// and there is no third component; before 10 it was
/// major * 10000 + minor * 100 + patch (9.1.5 -> 90105).
utils::version decode_pg_version(const int raw) {
  const auto v = static_cast<unsigned int>(raw);
  if (v >= 100000U) {
    return utils::version{v / 10000U, v % 10000U, 0U};
  }
  return utils::version{v / 10000U, (v % 10000U) / 100U, v % 100U};
}

}  // namespace

postgres_connection::string_to_int_map postgres_connection::statement_name_map_{};

postgres_connection::postgres_connection(const sql::connection_info &info)
  : connection_impl(info) {
}

/// Opens the connection described by info(); a no-op if already open.
/// On failure the handle is released and an OPEN_ERROR is returned.
utils::result<void, utils::error> postgres_connection::open() {
  if (conn_ != nullptr) {
    return utils::ok();
  }

  // Bind to a reference so the c_str() pointers below stay valid even if
  // info() returns by value (lifetime extension).
  const auto &ci = info();
  const std::string port = std::to_string(ci.port);

  // PQconnectdbParams takes keywords and values separately, so no conninfo
  // quoting is required: values containing spaces or quotes are passed
  // verbatim, and empty values are skipped instead of swallowing the next
  // "key=value" pair.
  const char *const keywords[] = {"user", "password", "host", "dbname", "port", nullptr};
  const char *const values[] = {
    ci.user.c_str(), ci.password.c_str(), ci.hostname.c_str(), ci.database.c_str(), port.c_str(), nullptr
  };

  conn_ = PQconnectdbParams(keywords, values, 0);
  if (PQstatus(conn_) == CONNECTION_BAD) {
    // Build the error while the handle is still alive so make_error can use
    // it; after PQfinish the handle (and its error message) is gone.
    auto error = make_error(sql::error_code::OPEN_ERROR, nullptr, conn_, "Failed to connect");
    PQfinish(conn_);
    conn_ = nullptr;
    return utils::failure(std::move(error));
  }

  return utils::ok();
}

/// Closes the connection if it is open; always succeeds.
utils::result<void, utils::error> postgres_connection::close() {
  if (conn_) {
    PQfinish(conn_);
    conn_ = nullptr;
  }

  return utils::ok();
}

/// True if a connection handle is held (says nothing about its health).
utils::result<bool, utils::error> postgres_connection::is_open() const {
  return utils::ok(conn_ != nullptr);
}

/// True if libpq reports the connection status as CONNECTION_OK.
utils::result<bool, utils::error> postgres_connection::is_valid() const {
  return utils::ok(PQstatus(conn_) == CONNECTION_OK);
}

/// Version of the linked libpq client library.
utils::result<utils::version, utils::error> postgres_connection::client_version() const {
  return utils::ok(decode_pg_version(PQlibVersion()));
}

/// Version of the connected server; fails if libpq reports 0 (no connection).
utils::result<utils::version, utils::error> postgres_connection::server_version() const {
  const int server_version = PQserverVersion(conn_);

  if (server_version == 0) {
    return utils::failure(make_error(sql::error_code::FAILURE, nullptr, conn_, "Failed to get server version"));
  }

  return utils::ok(decode_pg_version(server_version));
}

utils::basic_type oid2type(Oid oid);

/// Executes a query and wraps its result set in a query_result_impl.
///
/// Prototype columns that are still untyped (is_null()) get their type from
/// the column OID reported by the server. Fails if the number of returned
/// columns differs from the prototype.
utils::result<std::unique_ptr<sql::query_result_impl>, utils::error> postgres_connection::fetch(const sql::query_context &context) {
  PGresult *res = PQexec(conn_, context.sql.c_str());

  if (is_result_error(res)) {
    return utils::failure(make_error(sql::error_code::FETCH_FAILED, res, conn_, "Failed to fetch", context.sql));
  }

  auto prototype = context.prototype;

  const int num_col = PQnfields(res);
  if (prototype.size() != static_cast<size_t>(num_col)) {
    return utils::failure(make_error(sql::error_code::FETCH_FAILED, res, conn_, "Number of received columns doesn't match expected columns.", context.sql));
  }

  for (int i = 0; i < num_col; ++i) {
    auto &column = prototype.at(static_cast<size_t>(i));
    if (!column.is_null()) {
      continue;
    }
    column.change_type(oid2type(PQftype(res, i)));
  }

  // The result reader takes over res from here on.
  return utils::ok(std::make_unique<sql::query_result_impl>(std::make_unique<postgres_result_reader>(res), prototype));
}

/// Builds a unique statement name "<table>_<command>_<n>", where n counts up
/// per table/command pair.
///
/// NOTE(review): statement_name_map_ is a static member shared by all
/// connections and is accessed without a lock; confirm connections are not
/// used from several threads concurrently.
std::string postgres_connection::generate_statement_name(const sql::query_context &query) {
  std::stringstream name;
  name << query.table.name() << "_" << query.command_name;
  auto result = postgres_connection::statement_name_map_.find(name.str());
  if (result == postgres_connection::statement_name_map_.end()) {
    result = postgres_connection::statement_name_map_.insert(std::make_pair(name.str(), 0)).first;
  }

  name << "_" << ++result->second;

  return name.str();
}

/// Prepares context.sql on the server under a generated name and returns a
/// postgres_statement bound to it.
utils::result<std::unique_ptr<sql::statement_impl>, utils::error> postgres_connection::prepare(const sql::query_context &context) {
  auto statement_name = postgres_connection::generate_statement_name(context);

  PGresult *result = PQprepare(conn_, statement_name.c_str(), context.sql.c_str(),
                               static_cast<int>(context.bind_vars.size()), nullptr);

  if (is_result_error(result)) {
    return utils::failure(make_error(sql::error_code::PREPARE_FAILED, result, conn_, "Failed to prepare", context.sql));
  }

  // The prepared statement lives on the server under statement_name; the
  // PGresult only carries the status and must be freed here.
  PQclear(result);

  std::unique_ptr<sql::statement_impl> s(std::make_unique<postgres_statement>(conn_, statement_name, context));
  return utils::ok(std::move(s));
}

/// Executes a statement and returns the number of affected rows.
///
/// Commands that do not report a row count (e.g. CREATE TABLE) yield 0.
utils::result<size_t, utils::error> postgres_connection::execute(const std::string &stmt) {
  PGresult *res = PQexec(conn_, stmt.c_str());

  if (const auto status = PQresultStatus(res); status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK) {
    return utils::failure(make_error(sql::error_code::FAILURE, res, conn_, "Failed to execute", stmt));
  }

  // PQcmdTuples returns an empty string when the command carries no row count.
  const char *tuples = PQcmdTuples(res);
  size_t affected_rows = 0;
  if (tuples != nullptr && *tuples != '\0') {
    affected_rows = static_cast<size_t>(std::strtoull(tuples, nullptr, 10));
  }
  PQclear(res);

  return utils::ok(affected_rows);
}

/// Maps a PostgreSQL built-in type OID to a matador basic type; unknown OIDs
/// map to type_null.
utils::basic_type oid2type(const Oid oid) {
  switch (oid) {
    case 16:
      return utils::basic_type::type_bool;
    case 17:
      return utils::basic_type::type_blob;
    case 18:
      return utils::basic_type::type_int8;
    case 21:
      return utils::basic_type::type_int16;
    case 23:
      return utils::basic_type::type_int32;
    case 20:
      return utils::basic_type::type_int64;
    case 25:
      return utils::basic_type::type_text;
    case 1043:
      return utils::basic_type::type_varchar;
    case 700:
      return utils::basic_type::type_float;
    case 701:
      return utils::basic_type::type_double;
    case 1082:
      return utils::basic_type::type_date;
    case 1114:
      return utils::basic_type::type_time;
    default:
      return utils::basic_type::type_null;
  }
}

/// Maps an information_schema udt/data type name to a matador basic type;
/// unknown names map to type_null. Any name starting with "varchar" matches.
utils::basic_type string2type(const char *type) {
  if (strcmp(type, "int2") == 0) {
    return utils::basic_type::type_int16;
  }
  if (strcmp(type, "int4") == 0) {
    return utils::basic_type::type_int32;
  }
  if (strcmp(type, "int8") == 0) {
    return utils::basic_type::type_int64;
  }
  if (strcmp(type, "bool") == 0) {
    return utils::basic_type::type_bool;
  }
  if (strcmp(type, "date") == 0) {
    return utils::basic_type::type_date;
  }
  if (strcmp(type, "timestamp") == 0) {
    return utils::basic_type::type_time;
  }
  if (strcmp(type, "float4") == 0) {
    return utils::basic_type::type_float;
  }
  if (strcmp(type, "float8") == 0) {
    return utils::basic_type::type_double;
  }
  if (strncmp(type, "varchar", 7) == 0) {
    return utils::basic_type::type_varchar;
  }
  if (strcmp(type, "character varying") == 0) {
    return utils::basic_type::type_varchar;
  }
  if (strcmp(type, "text") == 0) {
    return utils::basic_type::type_text;
  }
  if (strcmp(type, "bytea") == 0) {
    return utils::basic_type::type_blob;
  }
  return utils::basic_type::type_null;
}

/// Reads the column definitions of a table in the 'public' schema from
/// information_schema.columns.
utils::result<std::vector<object::attribute_definition>, utils::error> postgres_connection::describe(const std::string &table) {
  // The table name is bound as $1 instead of being spliced into the SQL text,
  // so names containing quotes can neither break nor inject into the query.
  const std::string stmt(
    "SELECT ordinal_position, column_name, udt_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_schema='public' AND table_name=$1");
  const char *const params[] = {table.c_str()};

  PGresult *res = PQexecParams(conn_, stmt.c_str(), 1, nullptr, params, nullptr, nullptr, 0);

  if (is_result_error(res)) {
    return utils::failure(make_error(sql::error_code::DESCRIBE_FAILED, res, conn_, "Failed to describe", stmt));
  }

  postgres_result_reader reader(res);
  std::vector<object::attribute_definition> prototype;
  for (;;) {
    // Check every fetch explicitly so a read error is reported instead of
    // silently ending the loop with a partial result.
    auto fetched = reader.fetch();
    if (!fetched.is_ok()) {
      return utils::failure(fetched.release_error());
    }
    if (!*fetched) {
      break;
    }

    // ordinal_position is 1-based, attribute indices are 0-based.
    // TODO: handle a malformed ordinal_position
    const auto index = std::strtoul(reader.column(0), nullptr, 10) - 1;
    std::string name = reader.column(1);
    // TODO: extract size
    const auto type = string2type(reader.column(2));

    // is_nullable holds the text 'YES' or 'NO' (not a number).
    object::null_option_type null_opt{object::null_option_type::NULLABLE};
    const char *is_nullable = reader.column(4);
    if (is_nullable != nullptr && std::strcmp(is_nullable, "NO") == 0) {
      null_opt = object::null_option_type::NOT_NULL;
    }
    // TODO: evaluate column_default (column 5)
    prototype.emplace_back(name, type, utils::null_attributes, null_opt, index);
  }

  return utils::ok(std::move(prototype));
}

/// Checks whether table_name exists in schema_name.
utils::result<bool, utils::error> postgres_connection::exists(const std::string &schema_name, const std::string &table_name) {
  // Names are bound as parameters rather than concatenated into the SQL text.
  const std::string stmt(
    "SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2");
  const char *const params[] = {schema_name.c_str(), table_name.c_str()};

  PGresult *res = PQexecParams(conn_, stmt.c_str(), 2, nullptr, params, nullptr, nullptr, 0);

  if (is_result_error(res)) {
    return utils::failure(make_error(sql::error_code::TABLE_EXISTS_FAILED, res, conn_, "Failed check if table exists", stmt));
  }

  const bool found = PQntuples(res) > 0;
  PQclear(res);

  return utils::ok(found);
}

/// Escapes a blob for use as a bytea literal.
///
/// Returns an empty string if libpq cannot escape the data.
std::string postgres_connection::to_escaped_string(const utils::blob& value) const {
  size_t escaped_length = 0;
  unsigned char *escaped = PQescapeByteaConn(conn_, value.data(), value.size(), &escaped_length);
  if (escaped == nullptr) {
    return {};
  }

  // escaped_length includes the terminating zero byte.
  std::string result(reinterpret_cast<const char *>(escaped), escaped_length > 0 ? escaped_length - 1 : 0);
  // The buffer is allocated by libpq and must be released with PQfreemem.
  PQfreemem(escaped);
  return result;
}

}

extern "C"
{
  MATADOR_POSTGRES_API matador::sql::connection_impl *create_database(const matador::sql::connection_info &info) {
    return new matador::backends::postgres::postgres_connection(info);
  }

  MATADOR_POSTGRES_API void destroy_database(matador::sql::connection_impl *db) {
    delete db;
  }
}