// metal-kompanion/src/dal/PgDal.cpp
//
// (Source-viewer metadata from the original listing: 1105 lines, 34 KiB, C++.)

#include "PgDal.hpp"

#include <QCryptographicHash>
#include <QDateTime>
#include <QMetaType>
#include <QRandomGenerator>
#include <QSqlDatabase>
#include <QSqlDriver>
#include <QSqlError>
#include <QSqlQuery>
#include <QSqlRecord>
#include <QStringList>
#include <QTimeZone>
#include <QUrl>
#include <QUrlQuery>
#include <QVariant>

#include <algorithm>
#include <cctype>
#include <chrono>
#include <iomanip>
#include <limits>
#include <locale>
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
namespace kom {
namespace {
// Returns true when `value` occurs at least once in `ids` (linear scan).
bool idsContains(const std::vector<std::string>& ids, const std::string& value) {
    return std::any_of(ids.cbegin(), ids.cend(),
                       [&value](const std::string& candidate) { return candidate == value; });
}
} // namespace
// Default construction; initial member state comes from the class definition
// in PgDal.hpp (not visible here).
PgDal::PgDal() = default;
// Destruction releases the named Qt SQL connection, if any, via
// closeDatabase() (which rolls back a transaction still marked active).
PgDal::~PgDal() {
closeDatabase();
}
// A DSN selects the in-memory stub backend when it is empty or begins with
// the "stub://" scheme prefix.
bool PgDal::isStubDsn(const std::string& dsn) {
    if (dsn.empty()) {
        return true;
    }
    const std::string prefix = "stub://";
    return dsn.size() >= prefix.size() && dsn.compare(0, prefix.size(), prefix) == 0;
}
namespace {
// Converts a time point to a UTC QDateTime for binding. A default-constructed
// (epoch) time point is treated as "unset" and becomes a typed NULL.
QVariant toSqlTimestamp(const std::chrono::system_clock::time_point& tp) {
    const auto sinceEpoch = tp.time_since_epoch();
    if (sinceEpoch.count() != 0) {
        const auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(sinceEpoch).count();
        return QDateTime::fromMSecsSinceEpoch(millis, QTimeZone::UTC);
    }
    return QVariant(QMetaType(QMetaType::QDateTime));
}
// Optional overload: an empty optional becomes a typed NULL timestamp;
// otherwise defers to the time_point overload.
QVariant toSqlTimestamp(const std::optional<std::chrono::system_clock::time_point>& tp) {
    return tp.has_value() ? toSqlTimestamp(tp.value())
                          : QVariant(QMetaType(QMetaType::QDateTime));
}
std::chrono::system_clock::time_point fromSqlTimestamp(const QVariant& value) {
if (!value.isValid() || value.isNull()) {
return std::chrono::system_clock::time_point{};
}
QDateTime dt = value.toDateTime();
if (!dt.isValid()) {
dt = QDateTime::fromString(value.toString(), Qt::ISODateWithMs);
}
if (!dt.isValid()) {
return std::chrono::system_clock::time_point{};
}
dt = dt.toTimeZone(QTimeZone::UTC);
return std::chrono::system_clock::time_point(std::chrono::milliseconds(dt.toMSecsSinceEpoch()));
}
// Like fromSqlTimestamp(), but reports an invalid/NULL column as nullopt
// instead of the epoch time point.
std::optional<std::chrono::system_clock::time_point> fromSqlTimestampOptional(const QVariant& value) {
    const bool missing = !value.isValid() || value.isNull();
    if (missing) {
        return std::nullopt;
    }
    return fromSqlTimestamp(value);
}
std::optional<std::chrono::system_clock::time_point> parseIsoTimestamp(const std::optional<std::string>& iso) {
if (!iso || iso->empty()) {
return std::nullopt;
}
QDateTime dt = QDateTime::fromString(QString::fromStdString(*iso), Qt::ISODateWithMs);
if (!dt.isValid()) {
dt = QDateTime::fromString(QString::fromStdString(*iso), Qt::ISODate);
}
if (!dt.isValid()) {
return std::nullopt;
}
dt = dt.toTimeZone(QTimeZone::UTC);
return std::chrono::system_clock::time_point(std::chrono::milliseconds(dt.toMSecsSinceEpoch()));
}
// Subset test: true when every tag in `filterTags` appears in `rowTags`.
// An empty filter matches any row.
bool tagsMatch(const std::vector<std::string>& rowTags, const std::vector<std::string>& filterTags) {
    return std::all_of(filterTags.begin(), filterTags.end(), [&rowTags](const std::string& wanted) {
        return std::find(rowTags.begin(), rowTags.end(), wanted) != rowTags.end();
    });
}
} // namespace
// Connects the DAL to `dsn` and selects the storage backend.
//
// - Stub DSN (empty or "stub://...", see isStubDsn()): closes any SQL
//   connection, switches to the in-memory backend and clears every in-memory
//   table and id counter. Returns true.
// - Any other DSN: opened through openDatabase(). On success the SQL backend
//   is used and true is returned. If openDatabase() returns false (QPSQL
//   driver unavailable), the DAL falls back to the in-memory backend WITHOUT
//   clearing existing in-memory data, and returns false.
// Throws std::runtime_error (propagated from openDatabase()) when the server
// connection cannot be opened.
// NOTE(review): connected_ is set before openDatabase() runs, so the DAL stays
// marked connected after both a false return and a throw; after a throw
// useInMemory_ also keeps whatever value it had before.
bool PgDal::connect(const std::string& dsn) {
dsn_ = dsn;
connected_ = true;
// Cleared before closeDatabase() runs (directly below, or inside
// openDatabase()), so a transaction left open on a previous connection is not
// explicitly rolled back there.
transactionActive_ = false;
if (isStubDsn(dsn)) {
closeDatabase();
useInMemory_ = true;
// Reset the in-memory store and its id counters to a fresh state.
namespacesByName_.clear();
namespacesById_.clear();
items_.clear();
itemsByNamespace_.clear();
chunks_.clear();
chunksByItem_.clear();
embeddings_.clear();
nextNamespaceId_ = 1;
nextItemId_ = 1;
nextChunkId_ = 1;
nextEmbeddingId_ = 1;
return true;
}
if (openDatabase(dsn)) {
useInMemory_ = false;
return true;
}
// Driver unavailable: fall back to the in-memory backend.
useInMemory_ = true;
return false;
}
// Starts a transaction on the SQL connection.
// Returns false when not connected or no SQL connection is registered
// (e.g. the in-memory backend). A second begin() while a transaction is open
// is a no-op that returns true. Throws std::runtime_error if the driver
// refuses to start the transaction.
bool PgDal::begin() {
    const bool usable = connected_ && hasDatabase();
    if (!usable) {
        return false;
    }
    if (!transactionActive_) {
        QSqlDatabase db = database();
        if (!db.transaction()) {
            throw std::runtime_error(db.lastError().text().toStdString());
        }
        transactionActive_ = true;
    }
    return true;
}
// Commits the transaction opened by begin(); a no-op when none is active.
// Throws std::runtime_error on failure, leaving the transaction marked active
// so that rollback() can still clean it up.
void PgDal::commit() {
    if (transactionActive_) {
        QSqlDatabase db = database();
        if (!db.commit()) {
            throw std::runtime_error(db.lastError().text().toStdString());
        }
        transactionActive_ = false;
    }
}
void PgDal::rollback() {
if (!transactionActive_) {
return;
}
QSqlDatabase db = database();
db.rollback();
transactionActive_ = false;
}
// True when a connection name has been recorded and Qt still has a
// connection registered under it.
bool PgDal::hasDatabase() const {
    if (connectionName_.isEmpty()) {
        return false;
    }
    return QSqlDatabase::contains(connectionName_);
}
// Returns Qt's handle for this DAL's named connection (QSqlDatabase::database
// opens it if it is not already open), or an invalid default-constructed
// handle when no connection is registered.
QSqlDatabase PgDal::database() const {
    return hasDatabase() ? QSqlDatabase::database(connectionName_) : QSqlDatabase();
}
// Closes and unregisters this DAL's named Qt SQL connection, if one exists.
// A transaction still marked active is rolled back first. Safe to call
// repeatedly: it is a no-op once connectionName_ is empty.
void PgDal::closeDatabase() {
if (connectionName_.isEmpty()) {
return;
}
// Inner scope: the QSqlDatabase handle must be destroyed before
// removeDatabase() runs; otherwise Qt reports the connection as still in use.
{
// `false`: fetch the handle without (re)opening the connection.
QSqlDatabase db = QSqlDatabase::database(connectionName_, false);
if (db.isValid()) {
if (transactionActive_) {
db.rollback();
transactionActive_ = false;
}
db.close();
}
}
QSqlDatabase::removeDatabase(connectionName_);
connectionName_.clear();
}
// Parses a PostgreSQL URL-style DSN into a ConnectionConfig.
//
// Accepts "postgresql://user:pw@host:port/dbname?opt=v" as well as scheme-less
// forms such as "user@host/dbname"; a "postgresql://" scheme is prefixed when
// missing. Defaults: database "kompanion", host "localhost", port 5432.
// A "?host=/path" query item selects a Unix-domain socket directory. All other
// query items are forwarded as ';'-separated driver connect options.
PgDal::ConnectionConfig PgDal::parseDsn(const std::string& dsn) const {
    ConnectionConfig cfg;
    QUrl url(QString::fromStdString(dsn));
    if (url.scheme().isEmpty()) {
        url = QUrl(QStringLiteral("postgresql://") + QString::fromStdString(dsn));
    }

    // Database name = URL path without its leading '/'. Both a missing path
    // and a bare "/" (e.g. "postgresql://host/") fall back to the default;
    // previously the bare "/" produced an empty database name.
    QString dbName = url.path();
    if (dbName.startsWith(QLatin1Char('/'))) {
        dbName.remove(0, 1);
    }
    cfg.dbname = dbName.isEmpty() ? QStringLiteral("kompanion") : dbName;
    cfg.host = url.host().isEmpty() ? QStringLiteral("localhost") : url.host();
    cfg.port = url.port(5432);
    cfg.user = url.userName();
    cfg.password = url.password();

    QUrlQuery query(url);
    // libpq treats a host value beginning with '/' as a socket directory.
    // NOTE(review): a non-path "host" query value is ignored (not used as TCP host).
    if (query.hasQueryItem(QStringLiteral("host"))) {
        const QString hostValue = query.queryItemValue(QStringLiteral("host"));
        if (hostValue.startsWith(QStringLiteral("/"))) {
            cfg.useSocket = true;
            cfg.socketPath = hostValue;
        }
    }
    // Defensive default in case ConnectionConfig starts with useSocket set.
    if (cfg.useSocket && cfg.socketPath.isEmpty()) {
        cfg.socketPath = QStringLiteral("/var/run/postgresql");
    }

    // Remaining query items become "key=value" connect options, joined with
    // ';' as QSqlDatabase::setConnectOptions expects.
    QStringList optionPairs;
    const auto queryItems = query.queryItems();
    for (const auto& item : queryItems) {
        if (item.first == QStringLiteral("host")) {
            continue;
        }
        optionPairs << QStringLiteral("%1=%2").arg(item.first, item.second);
    }
    cfg.options = optionPairs.join(QLatin1Char(';'));
    return cfg;
}
// Registers and opens a QPSQL connection for `dsn`, replacing any previous one.
// Returns false (without throwing) when the QPSQL driver is unavailable.
// Returns true once the connection is open. Throws std::runtime_error if
// opening fails; in that case the half-configured connection is unregistered
// first.
bool PgDal::openDatabase(const std::string& dsn) {
    if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
        return false;
    }
    closeDatabase();

    const ConnectionConfig cfg = parseDsn(dsn);
    // Connection name is derived from this object's address.
    connectionName_ = QStringLiteral("kom_dal_%1").arg(reinterpret_cast<quintptr>(this), 0, 16);
    QSqlDatabase db = QSqlDatabase::addDatabase(QStringLiteral("QPSQL"), connectionName_);
    db.setDatabaseName(cfg.dbname);
    if (!cfg.user.isEmpty()) db.setUserName(cfg.user);
    if (!cfg.password.isEmpty()) db.setPassword(cfg.password);
    // For socket connections the socket directory is passed as the host name.
    db.setHostName(cfg.useSocket ? cfg.socketPath : cfg.host);
    if (cfg.port > 0) db.setPort(cfg.port);
    if (!cfg.options.isEmpty()) db.setConnectOptions(cfg.options);

    if (db.open()) {
        return true;
    }
    const std::string err = db.lastError().text().toStdString();
    // NOTE(review): `db` is still alive when closeDatabase() calls
    // removeDatabase(); Qt may warn that the connection is still in use.
    closeDatabase();
    throw std::runtime_error("PgDal: failed to open database: " + err);
}
// Lookup-or-create for a namespace by name. Returns nullopt only when the DAL
// is not connected. The secret produced on creation is discarded here; call
// createNamespaceWithSecret() directly to obtain it.
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) {
    if (!connected_) {
        return std::nullopt;
    }
    std::optional<NamespaceRow> existing = findNamespace(name);
    if (existing) {
        return existing;
    }
    auto created = createNamespaceWithSecret(name);
    return created.first;
}
// Looks up a namespace by name in the active backend (SQL when a live
// connection is in use, otherwise the in-memory map).
std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const {
    if (useInMemory_ || !hasDatabase()) {
        const auto hit = namespacesByName_.find(name);
        if (hit == namespacesByName_.end()) {
            return std::nullopt;
        }
        return hit->second;
    }
    return sqlFindNamespace(name);
}
std::optional<NamespaceRow> PgDal::sqlFindNamespace(const std::string& name) const {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"SELECT id::text, name FROM namespaces WHERE name = :name"));
query.bindValue(QStringLiteral(":name"), QString::fromStdString(name));
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
if (!query.next()) {
return std::nullopt;
}
NamespaceRow row;
row.id = query.value(0).toString().toStdString();
row.name = query.value(1).toString().toStdString();
return row;
}
std::pair<NamespaceRow, std::string> PgDal::createNamespaceWithSecret(const std::string& name) {
if (!connected_) {
throw std::runtime_error("PgDal not connected");
}
if (!useInMemory_ && hasDatabase()) {
return sqlCreateNamespaceWithSecret(name);
}
// In-memory implementation
auto it = namespacesByName_.find(name);
if (it != namespacesByName_.end()) {
// For in-memory, we don't have secrets, so we can't return one.
// This path should ideally not be taken in production.
return {it->second, ""};
}
NamespaceRow row;
row.id = allocateId(nextNamespaceId_, "ns_");
row.name = name;
namespacesByName_[name] = row;
namespacesById_[row.id] = row;
// Secrets are not supported in-memory for now
return {row, ""};
}
// Returns the stored secret hash for a namespace. Only the SQL backend stores
// secrets; the in-memory backend always yields nullopt.
std::optional<AuthSecret> PgDal::findSecretByNamespaceId(const std::string& namespaceId) const {
    const bool sqlBackend = !useInMemory_ && hasDatabase();
    if (!sqlBackend) {
        return std::nullopt;
    }
    return sqlFindSecretByNamespaceId(namespaceId);
}
std::pair<NamespaceRow, std::string> PgDal::sqlCreateNamespaceWithSecret(const std::string& name) {
QSqlDatabase db = database();
QSqlQuery query(db);
// 1. Create the namespace
query.prepare(QStringLiteral(
"INSERT INTO namespaces (name) VALUES (:name) "
"ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name "
"RETURNING id::text, name;"));
query.bindValue(QStringLiteral(":name"), QString::fromStdString(name));
if (!query.exec() || !query.next()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
NamespaceRow row;
row.id = query.value(0).toString().toStdString();
row.name = query.value(1).toString().toStdString();
// 2. Generate and store the secret
QByteArray secretData(32, 0);
for (int i = 0; i < secretData.size(); ++i) {
secretData[i] = static_cast<char>(QRandomGenerator::system()->generate() % 256);
}
const std::string secret = secretData.toHex().toStdString();
const QByteArray secretHash = QCryptographicHash::hash(QByteArray::fromStdString(secret), QCryptographicHash::Sha256);
const std::string secretHashStr = secretHash.toHex().toStdString();
sqlInsertSecret(row.id, secretHashStr);
return {row, secret};
}
void PgDal::sqlInsertSecret(const std::string& namespaceId, const std::string& secretHash) {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"INSERT INTO auth_secrets (namespace_id, secret_hash) "
"VALUES (:namespace_id::uuid, :secret_hash) "
"ON CONFLICT (namespace_id) DO UPDATE SET secret_hash = EXCLUDED.secret_hash;"));
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(namespaceId));
query.bindValue(QStringLiteral(":secret_hash"), QString::fromStdString(secretHash));
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
}
std::optional<AuthSecret> PgDal::sqlFindSecretByNamespaceId(const std::string& namespaceId) const {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"SELECT secret_hash FROM auth_secrets WHERE namespace_id = :namespace_id::uuid"));
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(namespaceId));
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
if (!query.next()) {
return std::nullopt;
}
AuthSecret secret;
secret.secret_hash = query.value(0).toString().toStdString();
return secret;
}
// Inserts or replaces an item and returns its id.
// Throws std::runtime_error when not connected. In memory: a missing id is
// allocated ("item_N"), replacing an existing id bumps its revision, and the
// id is indexed under its namespace once.
std::string PgDal::upsertItem(const ItemRow& row) {
    if (!connected_) {
        throw std::runtime_error("PgDal not connected");
    }
    const bool sqlBackend = !useInMemory_ && hasDatabase();
    if (sqlBackend) {
        return sqlUpsertItem(row).first;
    }
    ItemRow stored = row;
    if (stored.id.empty()) {
        stored.id = allocateId(nextItemId_, "item_");
    }
    if (const auto prior = items_.find(stored.id); prior != items_.end()) {
        stored.revision = prior->second.revision + 1;
    }
    items_[stored.id] = stored;
    auto& namespaceIds = itemsByNamespace_[stored.namespace_id];
    if (!idsContains(namespaceIds, stored.id)) {
        namespaceIds.push_back(stored.id);
    }
    return stored.id;
}
std::pair<std::string, int> PgDal::sqlUpsertItem(const ItemRow& row) {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"INSERT INTO memory_items (id, namespace_id, key, content, metadata, tags, text, created_at, expires_at) "
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
" :namespace_id::uuid, :key, :content, :metadata::jsonb, :tags::text[], :text, "
" COALESCE(:created_at, now()), :expires_at) "
"ON CONFLICT (id) DO UPDATE SET "
" key = EXCLUDED.key, content = EXCLUDED.content, metadata = EXCLUDED.metadata, "
" tags = EXCLUDED.tags, text = EXCLUDED.text, updated_at = now(), "
" expires_at = EXCLUDED.expires_at, "
" revision = memory_items.revision + 1 "
"RETURNING id::text, revision;"));
query.bindValue(QStringLiteral(":id"), QString::fromStdString(row.id));
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(row.namespace_id));
if (row.key) {
query.bindValue(QStringLiteral(":key"), QString::fromStdString(*row.key));
} else {
query.bindValue(QStringLiteral(":key"), QVariant(QMetaType(QMetaType::QString)));
}
query.bindValue(QStringLiteral(":content"), QString::fromStdString(row.content_json));
query.bindValue(QStringLiteral(":metadata"), QString::fromStdString(row.metadata_json));
query.bindValue(QStringLiteral(":tags"), QString::fromStdString(toPgArrayLiteral(row.tags)));
if (row.text) {
query.bindValue(QStringLiteral(":text"), QString::fromStdString(*row.text));
} else {
query.bindValue(QStringLiteral(":text"), QVariant(QMetaType(QMetaType::QString)));
}
query.bindValue(QStringLiteral(":created_at"), toSqlTimestamp(row.created_at));
query.bindValue(QStringLiteral(":expires_at"), toSqlTimestamp(row.expires_at));
if (!query.exec() || !query.next()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
std::pair<std::string, int> result;
result.first = query.value(0).toString().toStdString();
result.second = query.value(1).toInt();
return result;
}
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) {
if (!connected_) {
throw std::runtime_error("PgDal not connected");
}
if (!useInMemory_ && hasDatabase()) {
return sqlUpsertChunks(chunks);
}
std::vector<std::string> ids;
ids.reserve(chunks.size());
for (const auto& input : chunks) {
ChunkRow stored = input;
if (stored.item_id.empty()) {
continue;
}
if (stored.id.empty()) {
stored.id = allocateId(nextChunkId_, "chunk_");
}
chunks_[stored.id] = stored;
auto& bucket = chunksByItem_[stored.item_id];
if (!idsContains(bucket, stored.id)) {
bucket.push_back(stored.id);
}
ids.push_back(stored.id);
}
return ids;
}
std::vector<std::string> PgDal::sqlUpsertChunks(const std::vector<ChunkRow>& chunks) {
std::vector<std::string> ids;
ids.reserve(chunks.size());
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"INSERT INTO memory_chunks (id, item_id, seq, content) "
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
" :item_id::uuid, :seq, :content) "
"ON CONFLICT (id) DO UPDATE SET seq = EXCLUDED.seq, content = EXCLUDED.content "
"RETURNING id::text;"));
for (const auto& chunk : chunks) {
query.bindValue(QStringLiteral(":id"), QString::fromStdString(chunk.id));
query.bindValue(QStringLiteral(":item_id"), QString::fromStdString(chunk.item_id));
query.bindValue(QStringLiteral(":seq"), chunk.ord);
query.bindValue(QStringLiteral(":content"), QString::fromStdString(chunk.text));
if (!query.exec() || !query.next()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
ids.push_back(query.value(0).toString().toStdString());
query.finish();
}
return ids;
}
void PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embeddings) {
if (!connected_) {
throw std::runtime_error("PgDal not connected");
}
if (!useInMemory_ && hasDatabase()) {
sqlUpsertEmbeddings(embeddings);
return;
}
for (const auto& input : embeddings) {
if (input.chunk_id.empty()) {
continue;
}
EmbeddingRow stored = input;
if (stored.id.empty()) {
stored.id = allocateId(nextEmbeddingId_, "emb_");
}
embeddings_[stored.chunk_id] = stored;
}
}
void PgDal::sqlUpsertEmbeddings(const std::vector<EmbeddingRow>& embeddings) {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"INSERT INTO embeddings (id, chunk_id, model, dim, vector, normalized) "
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
" :chunk_id::uuid, :model, :dim, :vector::vector, FALSE) "
"ON CONFLICT (chunk_id, model) DO UPDATE SET "
" dim = EXCLUDED.dim, vector = EXCLUDED.vector, normalized = EXCLUDED.normalized "
"RETURNING id::text;"));
for (const auto& emb : embeddings) {
query.bindValue(QStringLiteral(":id"), QString::fromStdString(emb.id));
query.bindValue(QStringLiteral(":chunk_id"), QString::fromStdString(emb.chunk_id));
query.bindValue(QStringLiteral(":model"), QString::fromStdString(emb.model));
query.bindValue(QStringLiteral(":dim"), emb.dim);
query.bindValue(QStringLiteral(":vector"), QString::fromStdString(toPgVectorLiteral(emb.vector)));
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
query.finish();
}
}
// Case-insensitive substring search over item text within a namespace.
// An empty query matches every item. Returns at most `limit` items; the
// in-memory backend visits items in insertion order.
std::vector<ItemRow> PgDal::searchText(const std::string& namespaceId,
                                       const std::string& queryText,
                                       int limit) {
    if (!useInMemory_ && hasDatabase()) {
        return sqlSearchText(namespaceId, queryText, limit);
    }
    std::vector<ItemRow> results;
    // A non-positive limit returns nothing (like SQL "LIMIT 0"). Previously the
    // limit was checked only after the first push_back, so limit <= 0 still
    // returned one item.
    if (!connected_ || limit <= 0) return results;
    auto bucketIt = itemsByNamespace_.find(namespaceId);
    if (bucketIt == itemsByNamespace_.end()) return results;
    const std::string loweredQuery = toLower(queryText);
    for (const auto& itemId : bucketIt->second) {
        auto itemIt = items_.find(itemId);
        if (itemIt == items_.end()) continue;
        if (!loweredQuery.empty()) {
            const std::string loweredText = toLower(itemIt->second.text.value_or(std::string()));
            if (loweredText.find(loweredQuery) == std::string::npos) {
                continue;
            }
        }
        results.push_back(itemIt->second);
        if (static_cast<int>(results.size()) >= limit) break;
    }
    return results;
}
// SQL text search: non-deleted items in the namespace whose text contains
// `queryText` case-insensitively (an empty query matches all), newest
// updated_at first, at most `limit` rows. Throws std::runtime_error if the
// query fails.
std::vector<ItemRow> PgDal::sqlSearchText(const std::string& namespaceId,
                                          const std::string& queryText,
                                          int limit) const {
    std::vector<ItemRow> results;
    // PostgreSQL rejects a negative LIMIT; LIMIT 0 returns nothing anyway.
    if (limit <= 0) {
        return results;
    }
    // The query is a literal substring (as in the in-memory backend), so escape
    // LIKE metacharacters. Previously '%' and '_' in user input acted as
    // wildcards. Backslash is PostgreSQL's default LIKE escape character.
    std::string pattern;
    pattern.reserve(queryText.size());
    for (const char c : queryText) {
        if (c == '\\' || c == '%' || c == '_') {
            pattern.push_back('\\');
        }
        pattern.push_back(c);
    }
    QSqlDatabase db = database();
    QSqlQuery query(db);
    query.prepare(QStringLiteral(
        "SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
        "FROM memory_items "
        "WHERE namespace_id = :ns::uuid "
        " AND deleted_at IS NULL "
        " AND (:query = '' OR text ILIKE '%' || :query || '%') "
        "ORDER BY updated_at DESC "
        "LIMIT :limit;"));
    query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
    query.bindValue(QStringLiteral(":query"), QString::fromStdString(pattern));
    query.bindValue(QStringLiteral(":limit"), limit);
    if (!query.exec()) {
        throw std::runtime_error(query.lastError().text().toStdString());
    }
    while (query.next()) {
        ItemRow row;
        row.id = query.value(0).toString().toStdString();
        row.namespace_id = query.value(1).toString().toStdString();
        if (!query.value(2).isNull()) {
            row.key = query.value(2).toString().toStdString();
        }
        row.content_json = query.value(3).toString().toStdString();
        row.metadata_json = query.value(4).toString().toStdString();
        if (!query.value(5).isNull()) {
            row.text = query.value(5).toString().toStdString();
        }
        row.tags = parsePgTextArray(query.value(6).toString());
        row.revision = query.value(7).toInt();
        row.created_at = fromSqlTimestamp(query.value(8));
        row.expires_at = fromSqlTimestampOptional(query.value(9));
        results.push_back(std::move(row));
    }
    return results;
}
// Vector search within a namespace. Returns {item id, score} pairs sorted by
// score descending (ties by id), at most `limit` entries.
std::vector<std::pair<std::string, float>> PgDal::searchVector(
    const std::string& namespaceId,
    const std::vector<float>& embedding,
    int limit) {
    if (!useInMemory_ && hasDatabase()) {
        return sqlSearchVector(namespaceId, embedding, limit);
    }
    std::vector<std::pair<std::string, float>> scores;
    // A non-positive limit yields nothing. Previously a negative limit reached
    // resize(static_cast<std::size_t>(limit)), i.e. an enormous allocation.
    if (!connected_ || embedding.empty() || limit <= 0) return scores;
    auto bucketIt = itemsByNamespace_.find(namespaceId);
    if (bucketIt == itemsByNamespace_.end()) return scores;
    for (const auto& itemId : bucketIt->second) {
        auto chunkBucketIt = chunksByItem_.find(itemId);
        if (chunkBucketIt == chunksByItem_.end()) continue;
        // Best raw dot product over this item's chunk embeddings of matching
        // dimension (cosine similarity only if vectors are unit length).
        float bestScore = -1.0f;
        for (const auto& chunkId : chunkBucketIt->second) {
            auto embIt = embeddings_.find(chunkId);
            if (embIt == embeddings_.end()) continue;
            const auto& storedVec = embIt->second.vector;
            if (storedVec.size() != embedding.size() || storedVec.empty()) continue;
            const float dot = std::inner_product(storedVec.begin(), storedVec.end(), embedding.begin(), 0.0f);
            if (dot > bestScore) {
                bestScore = dot;
            }
        }
        // NOTE(review): items whose best score is negative are dropped here,
        // whereas the SQL backend also returns negative scores.
        if (bestScore >= 0.0f) {
            scores.emplace_back(itemId, bestScore);
        }
    }
    std::sort(scores.begin(), scores.end(),
              [](const auto& lhs, const auto& rhs) {
                  if (lhs.second == rhs.second) {
                      return lhs.first < rhs.first;
                  }
                  return lhs.second > rhs.second;
              });
    const auto maxCount = static_cast<std::size_t>(limit);
    if (scores.size() > maxCount) {
        scores.resize(maxCount);
    }
    return scores;
}
// SQL vector search: cosine similarity (1 - pgvector cosine distance) between
// `embedding` and chunk embeddings of non-deleted items in the namespace.
// Returns at most one {item id, score} entry per item, best score first.
// Because rows are limited before de-duplication, fewer than `limit` items may
// come back. Throws std::runtime_error if the query fails.
std::vector<std::pair<std::string, float>> PgDal::sqlSearchVector(
    const std::string& namespaceId,
    const std::vector<float>& embedding,
    int limit) const {
    std::vector<std::pair<std::string, float>> results;
    // PostgreSQL rejects a negative LIMIT; LIMIT 0 returns nothing anyway.
    if (embedding.empty() || limit <= 0) {
        return results;
    }
    QSqlDatabase db = database();
    QSqlQuery query(db);
    // Order by the same operator the score uses (cosine distance, <=>).
    // Previously rows were ordered by L2 distance (<->) while the reported
    // score was cosine similarity, so the order need not follow the score.
    query.prepare(QStringLiteral(
        "SELECT i.id::text, 1 - (e.vector <=> :vector::vector) AS score "
        "FROM embeddings e "
        "JOIN memory_chunks c ON c.id = e.chunk_id "
        "JOIN memory_items i ON i.id = c.item_id "
        "WHERE i.namespace_id = :ns::uuid "
        " AND i.deleted_at IS NULL "
        "ORDER BY e.vector <=> :vector::vector "
        "LIMIT :limit;"));
    query.bindValue(QStringLiteral(":vector"), QString::fromStdString(toPgVectorLiteral(embedding)));
    query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
    query.bindValue(QStringLiteral(":limit"), limit);
    if (!query.exec()) {
        throw std::runtime_error(query.lastError().text().toStdString());
    }
    while (query.next()) {
        std::string itemId = query.value(0).toString().toStdString();
        // One row per chunk: rows arrive best-first, so keep only the first
        // (highest-scoring) row for each item.
        const bool seen = std::any_of(results.begin(), results.end(),
                                      [&itemId](const auto& entry) { return entry.first == itemId; });
        if (seen) {
            continue;
        }
        results.emplace_back(std::move(itemId), static_cast<float>(query.value(1).toDouble()));
    }
    return results;
}
// Fetches a single item by id from the active backend; nullopt if absent.
std::optional<ItemRow> PgDal::getItemById(const std::string& id) const {
    if (useInMemory_ || !hasDatabase()) {
        const auto found = items_.find(id);
        if (found == items_.end()) {
            return std::nullopt;
        }
        return found->second;
    }
    return sqlGetItemById(id);
}
std::optional<ItemRow> PgDal::sqlGetItemById(const std::string& id) const {
QSqlDatabase db = database();
QSqlQuery query(db);
query.prepare(QStringLiteral(
"SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
"FROM memory_items WHERE id = :id::uuid"));
query.bindValue(QStringLiteral(":id"), QString::fromStdString(id));
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
if (!query.next()) {
return std::nullopt;
}
ItemRow row;
row.id = query.value(0).toString().toStdString();
row.namespace_id = query.value(1).toString().toStdString();
if (!query.value(2).isNull()) {
row.key = query.value(2).toString().toStdString();
}
row.content_json = query.value(3).toString().toStdString();
row.metadata_json = query.value(4).toString().toStdString();
if (!query.value(5).isNull()) {
row.text = query.value(5).toString().toStdString();
}
row.tags = parsePgTextArray(query.value(6).toString());
row.revision = query.value(7).toInt();
row.created_at = fromSqlTimestamp(query.value(8));
row.expires_at = fromSqlTimestampOptional(query.value(9));
return row;
}
// Returns items of a namespace, newest created_at first, at most `limit`.
// Optional filters:
//   - key: exact match (an empty key means "no key filter", as in SQL);
//   - tags: the row must carry every listed tag;
//   - sinceIso: created_at >= the parsed ISO-8601 instant; ignored when
//     absent, empty or unparseable.
std::vector<ItemRow> PgDal::fetchContext(const std::string& namespaceId,
                                         const std::optional<std::string>& key,
                                         const std::vector<std::string>& tags,
                                         const std::optional<std::string>& sinceIso,
                                         int limit) {
    if (!useInMemory_ && hasDatabase()) {
        return sqlFetchContext(namespaceId, key, tags, sinceIso, limit);
    }
    std::vector<ItemRow> results;
    // A non-positive limit yields nothing. Previously a negative limit reached
    // resize(static_cast<std::size_t>(limit)), i.e. an enormous allocation.
    if (limit <= 0) {
        return results;
    }
    auto bucketIt = itemsByNamespace_.find(namespaceId);
    if (bucketIt == itemsByNamespace_.end()) {
        return results;
    }
    // Match the SQL backend, which ignores an empty key. Previously an empty
    // key here required the row key to equal "".
    const bool filterByKey = key && !key->empty();
    const auto sinceTp = parseIsoTimestamp(sinceIso);
    for (const auto& itemId : bucketIt->second) {
        auto it = items_.find(itemId);
        if (it == items_.end()) continue;
        const ItemRow& row = it->second;
        if (filterByKey && (!row.key || *row.key != *key)) {
            continue;
        }
        if (!tags.empty() && !tagsMatch(row.tags, tags)) {
            continue;
        }
        if (sinceTp && row.created_at < *sinceTp) {
            continue;
        }
        results.push_back(row);
    }
    std::sort(results.begin(), results.end(), [](const ItemRow& a, const ItemRow& b) {
        return a.created_at > b.created_at;
    });
    const auto maxCount = static_cast<std::size_t>(limit);
    if (results.size() > maxCount) {
        results.resize(maxCount);
    }
    return results;
}
std::vector<ItemRow> PgDal::sqlFetchContext(const std::string& namespaceId,
const std::optional<std::string>& key,
const std::vector<std::string>& tags,
const std::optional<std::string>& sinceIso,
int limit) const {
QSqlDatabase db = database();
QString queryStr = QStringLiteral(
"SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
"FROM memory_items WHERE namespace_id = :ns::uuid AND deleted_at IS NULL");
if (key && !key->empty()) {
queryStr += QStringLiteral(" AND key = :key");
}
if (!tags.empty()) {
queryStr += QStringLiteral(" AND tags @> :tags::text[]");
}
if (sinceIso && !sinceIso->empty()) {
queryStr += QStringLiteral(" AND created_at >= :since");
}
queryStr += QStringLiteral(" ORDER BY created_at DESC LIMIT :limit");
QSqlQuery query(db);
query.prepare(queryStr);
query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
query.bindValue(QStringLiteral(":limit"), limit);
if (key && !key->empty()) {
query.bindValue(QStringLiteral(":key"), QString::fromStdString(*key));
}
if (!tags.empty()) {
query.bindValue(QStringLiteral(":tags"), QString::fromStdString(toPgArrayLiteral(tags)));
}
if (sinceIso && !sinceIso->empty()) {
const auto sinceTp = parseIsoTimestamp(sinceIso);
query.bindValue(QStringLiteral(":since"), toSqlTimestamp(sinceTp));
}
if (!query.exec()) {
throw std::runtime_error(query.lastError().text().toStdString());
}
std::vector<ItemRow> rows;
while (query.next()) {
ItemRow row;
row.id = query.value(0).toString().toStdString();
row.namespace_id = query.value(1).toString().toStdString();
if (!query.value(2).isNull()) {
row.key = query.value(2).toString().toStdString();
}
row.content_json = query.value(3).toString().toStdString();
row.metadata_json = query.value(4).toString().toStdString();
if (!query.value(5).isNull()) {
row.text = query.value(5).toString().toStdString();
}
row.tags = parsePgTextArray(query.value(6).toString());
row.revision = query.value(7).toInt();
row.created_at = fromSqlTimestamp(query.value(8));
row.expires_at = fromSqlTimestampOptional(query.value(9));
rows.push_back(std::move(row));
}
return rows;
}
std::pair<std::string, int> PgDal::upsertItem(
const std::string& namespace_id,
const std::optional<std::string>& key,
const std::string& content,
const std::string& metadata_json,
const std::vector<std::string>& tags) {
ItemRow row;
row.namespace_id = namespace_id;
row.key = key;
row.content_json = content;
row.metadata_json = metadata_json.empty() ? "{}" : metadata_json;
if (!content.empty()) {
row.text = content;
}
row.tags = tags;
row.created_at = std::chrono::system_clock::now();
if (!useInMemory_ && hasDatabase()) {
return sqlUpsertItem(row);
}
const std::string id = upsertItem(row);
const auto stored = items_.find(id);
const int revision = stored != items_.end() ? stored->second.revision : 1;
return {id, revision};
}
std::string PgDal::insertChunk(const std::string& item_id,
int seq,
const std::string& content) {
ChunkRow row;
row.item_id = item_id;
row.ord = seq;
row.text = content;
auto ids = upsertChunks(std::vector<ChunkRow>{row});
return ids.empty() ? std::string() : ids.front();
}
void PgDal::insertEmbedding(const Embedding& embedding) {
EmbeddingRow row;
row.chunk_id = embedding.chunk_id;
row.model = embedding.model;
row.dim = embedding.dim;
row.vector = embedding.vector;
upsertEmbeddings(std::vector<EmbeddingRow>{row});
}
// Hybrid search: text matches first, then vector matches not already listed,
// returning at most `k` unique item ids. `model` is currently unused.
std::vector<std::string> PgDal::hybridSearch(const std::vector<float>& query_vec,
                                             const std::string& model,
                                             const std::string& namespace_id,
                                             const std::string& query_text,
                                             int k) {
    (void)model;
    // Non-positive k: nothing to return. Previously the in-memory path could
    // still return an item for k <= 0, and the SQL path errored on a negative k.
    if (k <= 0) {
        return {};
    }
    if (!useInMemory_ && hasDatabase()) {
        return sqlHybridSearch(query_vec, model, namespace_id, query_text, k);
    }
    std::vector<std::string> results;
    auto textMatches = searchText(namespace_id, query_text, k);
    for (const auto& item : textMatches) {
        results.push_back(item.id);
        if (static_cast<int>(results.size()) >= k) {
            return results;
        }
    }
    auto vectorMatches = searchVector(namespace_id, query_vec, k);
    for (const auto& pair : vectorMatches) {
        if (!idsContains(results, pair.first)) {
            results.push_back(pair.first);
        }
        if (static_cast<int>(results.size()) >= k) break;
    }
    return results;
}
// SQL hybrid search: text matches (sqlSearchText order), then vector matches
// not already seen, up to `k` unique item ids. Vector search is skipped for an
// empty query vector. `model` is currently unused.
std::vector<std::string> PgDal::sqlHybridSearch(const std::vector<float>& query_vec,
                                                const std::string& model,
                                                const std::string& namespace_id,
                                                const std::string& query_text,
                                                int k) {
    (void)model;
    std::vector<std::string> results;
    // PostgreSQL rejects a negative LIMIT; a non-positive k yields no results
    // instead of a query error.
    if (k <= 0) {
        return results;
    }
    std::unordered_set<std::string> seen;
    auto textMatches = sqlSearchText(namespace_id, query_text, k);
    for (const auto& item : textMatches) {
        results.push_back(item.id);
        seen.insert(item.id);
        if (static_cast<int>(results.size()) >= k) {
            return results;
        }
    }
    if (!query_vec.empty()) {
        auto vectorMatches = sqlSearchVector(namespace_id, query_vec, k);
        for (const auto& pair : vectorMatches) {
            if (seen.count(pair.first)) continue;
            results.push_back(pair.first);
            seen.insert(pair.first);
            if (static_cast<int>(results.size()) >= k) break;
        }
    }
    return results;
}
// Returns `prefix` + the current counter value, then advances the counter.
std::string PgDal::allocateId(std::size_t& counter, const std::string& prefix) {
    const std::size_t current = counter;
    ++counter;
    return prefix + std::to_string(current);
}
// Lower-cases `value` byte by byte with std::tolower (current C locale).
// Multi-byte UTF-8 sequences are not case-folded as characters.
std::string PgDal::toLower(const std::string& value) {
    std::string lowered;
    lowered.reserve(value.size());
    for (const unsigned char byte : value) {
        lowered.push_back(static_cast<char>(std::tolower(byte)));
    }
    return lowered;
}
// Backslash-escapes the characters that are special inside a double-quoted
// PostgreSQL array element: '"' and '\'.
std::string PgDal::escapePgArrayElement(const std::string& value) {
    std::string escaped;
    escaped.reserve(value.size());
    for (std::size_t i = 0; i < value.size(); ++i) {
        const char ch = value[i];
        const bool special = (ch == '\\') || (ch == '"');
        if (special) {
            escaped += '\\';
        }
        escaped += ch;
    }
    return escaped;
}
// Builds a PostgreSQL text[] literal such as {"a","b\"c"}. Every element is
// quoted and escaped so that commas, braces and spaces survive; an empty
// vector yields "{}".
std::string PgDal::toPgArrayLiteral(const std::vector<std::string>& values) {
    std::string literal(1, '{');
    bool first = true;
    for (const std::string& element : values) {
        if (!first) {
            literal += ',';
        }
        first = false;
        literal += '"';
        literal += escapePgArrayElement(element);
        literal += '"';
    }
    literal += '}';
    return literal;
}
// Serialises a float vector as a pgvector text literal, e.g.
// "[0.5,-1,2.00000002e-07]"; an empty vector yields "[]".
//
// Previously this used std::to_string(float), which formats like "%f":
// - it always prints 6 fixed decimals, so components below 5e-7 became 0 and
//   precision was lost;
// - it follows the C locale's decimal separator, so a ',' locale corrupted
//   the literal.
// Now the classic locale is used with max_digits10 significant digits, so
// every float round-trips exactly.
std::string PgDal::toPgVectorLiteral(const std::vector<float>& values) {
    std::ostringstream out;
    out.imbue(std::locale::classic());
    out << std::setprecision(std::numeric_limits<float>::max_digits10);
    out << '[';
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (i != 0) {
            out << ',';
        }
        out << values[i];
    }
    out << ']';
    return out.str();
}
// Parses PostgreSQL's text[] output, e.g. {a,"b c","d\"e",""}, into strings.
// Returns an empty vector for "{}" or for input not wrapped in braces.
// Quoted elements may contain commas; a backslash escapes the next character.
// NOTE(review): an unquoted NULL element is returned as the string "NULL".
std::vector<std::string> PgDal::parsePgTextArray(const QString& value) {
    std::vector<std::string> tags;
    QString trimmed = value.trimmed();
    if (!trimmed.startsWith(QLatin1Char('{')) || !trimmed.endsWith(QLatin1Char('}'))) {
        return tags;
    }
    trimmed = trimmed.mid(1, trimmed.size() - 2);
    QString current;
    bool inQuotes = false;
    bool escape = false;
    // Set once the current element has content or an opening quote, so an
    // empty quoted element ("") is kept. Previously a trailing or sole ""
    // element was dropped, because only a non-empty leftover was emitted.
    bool elementStarted = false;
    for (QChar ch : trimmed) {
        if (escape) {
            current.append(ch);
            escape = false;
            continue;
        }
        if (ch == QLatin1Char('\\')) {
            escape = true;
            elementStarted = true;
            continue;
        }
        if (ch == QLatin1Char('"')) {
            inQuotes = !inQuotes;
            elementStarted = true;
            continue;
        }
        if (!inQuotes && ch == QLatin1Char(',')) {
            tags.push_back(current.toStdString());
            current.clear();
            elementStarted = false;
            continue;
        }
        current.append(ch);
        elementStarted = true;
    }
    if (elementStarted) {
        tags.push_back(current.toStdString());
    }
    return tags;
}
} // namespace kom