Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/sqlgen/postgres/Iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ class SQLGEN_API Iterator {

Iterator& operator=(Iterator&& _other) noexcept;

static rfl::Result<Ref<Iterator>> make(const std::string& _sql,
const Conn& _conn) noexcept {
try {
return Ref<Iterator>::make(_sql, _conn);
} catch (const std::exception& e) {
return error(e.what());
}
}

private:
static std::string make_cursor_name() {
// TODO: Create unique cursor names.
Expand Down
9 changes: 9 additions & 0 deletions include/sqlgen/postgres/PostgresV2Connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ class SQLGEN_API PostgresV2Connection {
static rfl::Result<PostgresV2Connection> make(
const std::string& _conn_str) noexcept;

static rfl::Result<PostgresV2Connection> make(PGconn* _ptr) noexcept {
try {
return PostgresV2Connection(_ptr);
} catch (const std::exception& e) {
return rfl::error("Failed to connect to postgres: " +
std::string(e.what()));
}
}

PGconn* ptr() const { return ptr_.get(); }

private:
Expand Down
9 changes: 9 additions & 0 deletions include/sqlgen/postgres/PostgresV2Result.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ class SQLGEN_API PostgresV2Result {
static rfl::Result<PostgresV2Result> make(
const std::string& _query, const PostgresV2Connection& _conn) noexcept;

static rfl::Result<PostgresV2Result> make(PGresult* _ptr) noexcept {
try {
return PostgresV2Result(_ptr);
} catch (const std::exception& e) {
return rfl::error("Failed to retrieve result from postgres: " +
std::string(e.what()));
}
}

PGresult* ptr() const { return ptr_.get(); }

private:
Expand Down
143 changes: 78 additions & 65 deletions src/sqlgen/postgres/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,31 @@ Result<Nothing> Connection::end_write() {
if (PQputCopyEnd(conn_.ptr(), NULL) == -1) {
return error(PQerrorMessage(conn_.ptr()));
}
const auto res = PostgresV2Result(PQgetResult(conn_.ptr()));
if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) {
return error(PQerrorMessage(conn_.ptr()));
}
return Nothing{};
return PostgresV2Result::make(PQgetResult(conn_.ptr()))
.and_then([&](auto&& res) -> Result<Nothing> {
if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) {
return error(PQerrorMessage(conn_.ptr()));
}
return Nothing{};
});
}

std::list<Notification> Connection::get_notifications() noexcept {
std::list<Notification> notices;

// Safe to call even if no data — just returns true
if (!PQconsumeInput(conn_.ptr())) {
// Note: In pure wait/consume pattern, this should rarely happen if socket is healthy
// But we don't error here — just skip
// Note: In pure wait/consume pattern, this should rarely happen if socket
// is healthy But we don't error here — just skip
return notices;
}

PGnotify* notify;
while ((notify = PQnotifies(conn_.ptr())) != nullptr) {
notices.push_back({
.channel = std::string(notify->relname),
.payload = notify->extra[0] ? std::string(notify->extra) : "",
.backend_pid = notify->be_pid
});
notices.push_back(
{.channel = std::string(notify->relname),
.payload = notify->extra[0] ? std::string(notify->extra) : "",
.backend_pid = notify->be_pid});
PQfreemem(notify);
}

Expand All @@ -83,16 +84,19 @@ rfl::Result<Nothing> Connection::unlisten(const std::string& channel) noexcept {
return execute(sql);
}

rfl::Result<Nothing> Connection::notify(const std::string& channel, const std::string& payload) noexcept {
rfl::Result<Nothing> Connection::notify(const std::string& channel,
const std::string& payload) noexcept {
if (!is_valid_channel_name(channel)) {
return error("Invalid channel name");
}

auto* escaped_payload = PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size());
auto* escaped_payload =
PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size());
if (!escaped_payload) {
return error("Failed to escape NOTIFY payload");
}
const std::string sql = "NOTIFY " + channel + ", " + std::string(escaped_payload);
const std::string sql =
"NOTIFY " + channel + ", " + std::string(escaped_payload);
PQfreemem(escaped_payload);

auto result = execute(sql);
Expand All @@ -116,55 +120,64 @@ Result<Nothing> Connection::insert_impl(

const auto sql = to_sql_impl(_stmt);

const auto res = PostgresV2Result(PQprepare(
conn_.ptr(), name.c_str(), sql.c_str(), _data.at(0).size(), nullptr));

const auto status = PQresultStatus(res.ptr());

if (status != PGRES_COMMAND_OK) {
return error("Generating prepared statement for '" + sql +
"' failed: " + PQresultErrorMessage(res.ptr()));
}

std::vector<const char*> current_row(_data[0].size());

const int n_params = static_cast<int>(current_row.size());

for (size_t i = 0; i < _data.size(); ++i) {
const auto& d = _data[i];

if (d.size() != current_row.size()) {
execute("DEALLOCATE " + name + ";");
return error("Error in entry " + std::to_string(i) + ": Expected " +
std::to_string(current_row.size()) + " entries, got " +
std::to_string(d.size()));
}

for (size_t j = 0; j < d.size(); ++j) {
current_row[j] = d[j] ? d[j]->c_str() : nullptr;
}

const auto res =
PostgresV2Result(PQexecPrepared(conn_.ptr(), // conn
name.c_str(), // stmtName
n_params, // nParams
current_row.data(), // paramValues
nullptr, // paramLengths
nullptr, // paramFormats
0 // resultFormat
));

const auto status = PQresultStatus(res.ptr());

if (status != PGRES_COMMAND_OK) {
const auto err = error(std::string("Executing INSERT failed: ") +
PQresultErrorMessage(res.ptr()));
execute("DEALLOCATE " + name + ";");
return err;
}
}

return execute("DEALLOCATE " + name + ";");
return PostgresV2Result::make(PQprepare(conn_.ptr(), name.c_str(),
sql.c_str(), _data.at(0).size(),
nullptr))
.and_then([&](auto&& res) -> Result<Nothing> {
const auto status = PQresultStatus(res.ptr());

if (status != PGRES_COMMAND_OK) {
return error("Generating prepared statement for '" + sql +
"' failed: " + PQresultErrorMessage(res.ptr()));
}

std::vector<const char*> current_row(_data[0].size());

const int n_params = static_cast<int>(current_row.size());

for (size_t i = 0; i < _data.size(); ++i) {
const auto& d = _data[i];

if (d.size() != current_row.size()) {
execute("DEALLOCATE " + name + ";");
return error("Error in entry " + std::to_string(i) + ": Expected " +
std::to_string(current_row.size()) + " entries, got " +
std::to_string(d.size()));
}

for (size_t j = 0; j < d.size(); ++j) {
current_row[j] = d[j] ? d[j]->c_str() : nullptr;
}

try {
const auto res = PostgresV2Result(PQexecPrepared(
conn_.ptr(), // conn
name.c_str(), // stmtName
n_params, // nParams
current_row.data(), // paramValues
nullptr, // paramLengths
nullptr, // paramFormats
0 // resultFormat
));

const auto status = PQresultStatus(res.ptr());

if (status != PGRES_COMMAND_OK) {
const auto err = error(std::string("Executing INSERT failed: ") +
PQresultErrorMessage(res.ptr()));
execute("DEALLOCATE " + name + ";");
return err;
}
} catch (const std::exception& e) {
const auto err =
error(std::string("Executing INSERT failed: ") + e.what());
execute("DEALLOCATE " + name + ";");
return err;
}
}

return execute("DEALLOCATE " + name + ";");
});
}

rfl::Result<Ref<Connection>> Connection::make(
Expand All @@ -176,7 +189,7 @@ rfl::Result<Ref<Connection>> Connection::make(
Result<Ref<Iterator>> Connection::read_impl(
const rfl::Variant<dynamic::SelectFrom, dynamic::Union>& _query) {
const auto sql = _query.visit([](const auto& _q) { return to_sql_impl(_q); });
return Ref<Iterator>::make(sql, conn_);
return Iterator::make(sql, conn_);
}

Result<Nothing> Connection::rollback() noexcept { return execute("ROLLBACK;"); }
Expand Down
46 changes: 46 additions & 0 deletions tests/duckdb/test_error_handling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

#include <gtest/gtest.h>

#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/duckdb.hpp>
#include <vector>

namespace test_error_handling {

struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};

TEST(duckdb, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});

using namespace sqlgen;
using namespace sqlgen::literals;

const auto people2 =
duckdb::connect()
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())))
.value_or(std::vector<Person>({}));

const std::string expected1 = R"([])";

EXPECT_EQ(rfl::json::write(people2), expected1);
}

} // namespace test_error_handling

55 changes: 55 additions & 0 deletions tests/mysql/test_error_handling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY

#include <gtest/gtest.h>

#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/mysql.hpp>
#include <vector>

namespace test_error_handling {

struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};

TEST(mysql, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});

const auto credentials = sqlgen::mysql::Credentials{.host = "localhost",
.user = "sqlgen",
.password = "password",
.dbname = "mysql"};

using namespace sqlgen;
using namespace sqlgen::literals;

const auto people2 =
mysql::connect(credentials)
.and_then(drop<Person> | if_exists)
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())) |
order_by("age"_c))
.value_or(std::vector<Person>({}));

const std::string expected1 = R"([])";

EXPECT_EQ(rfl::json::write(people2), expected1);
}

} // namespace test_error_handling

#endif
55 changes: 55 additions & 0 deletions tests/postgres/test_error_handling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY

#include <gtest/gtest.h>

#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/postgres.hpp>
#include <vector>

namespace test_error_handling {

struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};

TEST(postgres, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});

const auto credentials = sqlgen::postgres::Credentials{.user = "postgres",
.password = "password",
.host = "localhost",
.dbname = "postgres"};

using namespace sqlgen;
using namespace sqlgen::literals;

/// Intentionally passing an empty vector to test error handling
const auto people2 =
postgres::connect(credentials)
.and_then(drop<Person> | if_exists)
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())))
.value_or(std::vector<Person>({}));

const std::string expected1 = R"([])";

EXPECT_EQ(rfl::json::write(people2), expected1);
}

} // namespace test_error_handling

#endif
Loading
Loading