1105 lines
34 KiB
C++
1105 lines
34 KiB
C++
#include "PgDal.hpp"
|
|
|
|
#include <QSqlDatabase>
|
|
#include <QSqlDriver>
|
|
#include <QSqlError>
|
|
#include <QSqlQuery>
|
|
#include <QSqlRecord>
|
|
#include <QStringList>
|
|
#include <QMetaType>
|
|
#include <QDateTime>
|
|
#include <QTimeZone>
|
|
#include <QUrl>
|
|
#include <QUrlQuery>
|
|
#include <QVariant>
|
|
|
|
#include <algorithm>
|
|
#include <cctype>
|
|
#include <numeric>
|
|
#include <stdexcept>
|
|
|
|
namespace kom {
|
|
|
|
namespace {
|
|
|
|
// Linear membership test: true when `value` equals one of the entries in `ids`.
bool idsContains(const std::vector<std::string>& ids, const std::string& value) {
    for (const auto& candidate : ids) {
        if (candidate == value) {
            return true;
        }
    }
    return false;
}
|
|
|
|
|
|
} // namespace
|
|
|
|
// Default construction: all state comes from the in-class member initializers;
// no database connection is opened until connect() is called.
PgDal::PgDal() = default;
|
|
|
|
// Releases the Qt SQL connection owned by this instance, if one was opened.
PgDal::~PgDal() {
    // closeDatabase() returns immediately when no connection name is set.
    this->closeDatabase();
}
|
|
|
|
// A DSN selects the in-memory stub backend when it is empty or starts with "stub://".
bool PgDal::isStubDsn(const std::string& dsn) {
    constexpr char kStubScheme[] = "stub://";
    if (dsn.empty()) {
        return true;
    }
    // Compare only the leading characters; a shorter DSN can never match.
    return dsn.compare(0, sizeof(kStubScheme) - 1, kStubScheme) == 0;
}
|
|
|
|
namespace {
|
|
|
|
// Converts a time_point into a bindable QVariant.
// A time_point exactly at the epoch (count == 0) is treated as "unset" and
// becomes a null QDateTime variant; anything else becomes a UTC QDateTime with
// millisecond resolution.
QVariant toSqlTimestamp(const std::chrono::system_clock::time_point& tp) {
    const auto sinceEpoch = tp.time_since_epoch();
    if (sinceEpoch.count() != 0) {
        const auto millis =
            std::chrono::duration_cast<std::chrono::milliseconds>(sinceEpoch).count();
        return QDateTime::fromMSecsSinceEpoch(millis, QTimeZone::UTC);
    }
    return QVariant(QMetaType(QMetaType::QDateTime));
}
|
|
|
|
// Optional overload: an empty optional maps to a null QDateTime variant,
// otherwise the contained time_point is converted by the time_point overload.
QVariant toSqlTimestamp(const std::optional<std::chrono::system_clock::time_point>& tp) {
    return tp.has_value() ? toSqlTimestamp(*tp)
                          : QVariant(QMetaType(QMetaType::QDateTime));
}
|
|
|
|
std::chrono::system_clock::time_point fromSqlTimestamp(const QVariant& value) {
|
|
if (!value.isValid() || value.isNull()) {
|
|
return std::chrono::system_clock::time_point{};
|
|
}
|
|
QDateTime dt = value.toDateTime();
|
|
if (!dt.isValid()) {
|
|
dt = QDateTime::fromString(value.toString(), Qt::ISODateWithMs);
|
|
}
|
|
if (!dt.isValid()) {
|
|
return std::chrono::system_clock::time_point{};
|
|
}
|
|
dt = dt.toTimeZone(QTimeZone::UTC);
|
|
return std::chrono::system_clock::time_point(std::chrono::milliseconds(dt.toMSecsSinceEpoch()));
|
|
}
|
|
|
|
std::optional<std::chrono::system_clock::time_point> fromSqlTimestampOptional(const QVariant& value) {
|
|
if (!value.isValid() || value.isNull()) {
|
|
return std::nullopt;
|
|
}
|
|
return fromSqlTimestamp(value);
|
|
}
|
|
|
|
std::optional<std::chrono::system_clock::time_point> parseIsoTimestamp(const std::optional<std::string>& iso) {
|
|
if (!iso || iso->empty()) {
|
|
return std::nullopt;
|
|
}
|
|
QDateTime dt = QDateTime::fromString(QString::fromStdString(*iso), Qt::ISODateWithMs);
|
|
if (!dt.isValid()) {
|
|
dt = QDateTime::fromString(QString::fromStdString(*iso), Qt::ISODate);
|
|
}
|
|
if (!dt.isValid()) {
|
|
return std::nullopt;
|
|
}
|
|
dt = dt.toTimeZone(QTimeZone::UTC);
|
|
return std::chrono::system_clock::time_point(std::chrono::milliseconds(dt.toMSecsSinceEpoch()));
|
|
}
|
|
|
|
// True when every tag in `filterTags` is present in `rowTags`.
// An empty filter therefore matches any row.
bool tagsMatch(const std::vector<std::string>& rowTags, const std::vector<std::string>& filterTags) {
    return std::all_of(filterTags.begin(), filterTags.end(), [&rowTags](const std::string& wanted) {
        return std::find(rowTags.begin(), rowTags.end(), wanted) != rowTags.end();
    });
}
|
|
|
|
} // namespace
|
|
|
|
bool PgDal::connect(const std::string& dsn) {
|
|
dsn_ = dsn;
|
|
connected_ = true;
|
|
transactionActive_ = false;
|
|
|
|
if (isStubDsn(dsn)) {
|
|
closeDatabase();
|
|
useInMemory_ = true;
|
|
namespacesByName_.clear();
|
|
namespacesById_.clear();
|
|
items_.clear();
|
|
itemsByNamespace_.clear();
|
|
chunks_.clear();
|
|
chunksByItem_.clear();
|
|
embeddings_.clear();
|
|
nextNamespaceId_ = 1;
|
|
nextItemId_ = 1;
|
|
nextChunkId_ = 1;
|
|
nextEmbeddingId_ = 1;
|
|
return true;
|
|
}
|
|
|
|
if (openDatabase(dsn)) {
|
|
useInMemory_ = false;
|
|
return true;
|
|
}
|
|
|
|
useInMemory_ = true;
|
|
return false;
|
|
}
|
|
|
|
bool PgDal::begin() {
|
|
if (!connected_ || !hasDatabase()) {
|
|
return false;
|
|
}
|
|
|
|
if (transactionActive_) {
|
|
return true;
|
|
}
|
|
|
|
QSqlDatabase db = database();
|
|
if (!db.transaction()) {
|
|
throw std::runtime_error(db.lastError().text().toStdString());
|
|
}
|
|
transactionActive_ = true;
|
|
return true;
|
|
}
|
|
|
|
void PgDal::commit() {
|
|
if (!transactionActive_) {
|
|
return;
|
|
}
|
|
QSqlDatabase db = database();
|
|
if (!db.commit()) {
|
|
throw std::runtime_error(db.lastError().text().toStdString());
|
|
}
|
|
transactionActive_ = false;
|
|
}
|
|
|
|
// Rolls back the active transaction; does nothing when none is active.
// The rollback is best-effort: its result is not checked and the transaction
// flag is cleared regardless.
void PgDal::rollback() {
    if (transactionActive_) {
        QSqlDatabase db = database();
        db.rollback();
        transactionActive_ = false;
    }
}
|
|
|
|
bool PgDal::hasDatabase() const {
|
|
return !connectionName_.isEmpty() && QSqlDatabase::contains(connectionName_);
|
|
}
|
|
|
|
// Returns a handle to this instance's Qt connection, or a default-constructed
// (invalid) QSqlDatabase when no connection is registered.
QSqlDatabase PgDal::database() const {
    return hasDatabase() ? QSqlDatabase::database(connectionName_) : QSqlDatabase();
}
|
|
|
|
void PgDal::closeDatabase() {
|
|
if (connectionName_.isEmpty()) {
|
|
return;
|
|
}
|
|
{
|
|
QSqlDatabase db = QSqlDatabase::database(connectionName_, false);
|
|
if (db.isValid()) {
|
|
if (transactionActive_) {
|
|
db.rollback();
|
|
transactionActive_ = false;
|
|
}
|
|
db.close();
|
|
}
|
|
}
|
|
QSqlDatabase::removeDatabase(connectionName_);
|
|
connectionName_.clear();
|
|
}
|
|
|
|
// Breaks a PostgreSQL DSN/URL into the pieces needed to configure a QPSQL
// connection.
//
// - A DSN without a scheme is interpreted as if prefixed by "postgresql://".
// - Database name defaults to "kompanion" when the URL path is empty OR just
//   "/" (previously "/" produced an empty database name).
// - Host defaults to "localhost", port to 5432.
// - A `host=` query item whose value starts with '/' selects a Unix socket
//   directory; other `host=` values are ignored.
// - All remaining query items are joined as "key=value" pairs separated by
//   ';' and returned in `options`.
PgDal::ConnectionConfig PgDal::parseDsn(const std::string& dsn) const {
    ConnectionConfig cfg;

    QUrl url(QString::fromStdString(dsn));
    if (url.scheme().isEmpty()) {
        url = QUrl(QStringLiteral("postgresql://") + QString::fromStdString(dsn));
    }

    // Strip the leading '/' from the path; fall back to the default name when
    // nothing is left (covers both "" and "/").
    const QString pathDbName = url.path().mid(1);
    cfg.dbname = pathDbName.isEmpty() ? QStringLiteral("kompanion") : pathDbName;

    cfg.host = url.host().isEmpty() ? QStringLiteral("localhost") : url.host();
    cfg.port = url.port(5432);
    cfg.user = url.userName();
    cfg.password = url.password();

    const QUrlQuery query(url);
    if (query.hasQueryItem(QStringLiteral("host"))) {
        const QString hostValue = query.queryItemValue(QStringLiteral("host"));
        if (hostValue.startsWith(QStringLiteral("/"))) {
            cfg.useSocket = true;
            cfg.socketPath = hostValue;
        }
    }

    // Only reachable if ConnectionConfig defaults useSocket to true without a
    // path; kept so such a default still gets a usable socket directory.
    if (cfg.useSocket && cfg.socketPath.isEmpty()) {
        cfg.socketPath = QStringLiteral("/var/run/postgresql");
    }

    QStringList optionPairs;
    const auto queryItems = query.queryItems();
    for (const auto& item : queryItems) {
        if (item.first == QStringLiteral("host")) {
            continue;  // consumed above, never forwarded as a connect option
        }
        optionPairs << QStringLiteral("%1=%2").arg(item.first, item.second);
    }
    cfg.options = optionPairs.join(QLatin1Char(';'));

    return cfg;
}
|
|
|
|
bool PgDal::openDatabase(const std::string& dsn) {
|
|
if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
|
|
return false;
|
|
}
|
|
|
|
closeDatabase();
|
|
const ConnectionConfig cfg = parseDsn(dsn);
|
|
|
|
connectionName_ = QStringLiteral("kom_dal_%1").arg(reinterpret_cast<quintptr>(this), 0, 16);
|
|
QSqlDatabase db = QSqlDatabase::addDatabase(QStringLiteral("QPSQL"), connectionName_);
|
|
db.setDatabaseName(cfg.dbname);
|
|
if (!cfg.user.isEmpty()) {
|
|
db.setUserName(cfg.user);
|
|
}
|
|
if (!cfg.password.isEmpty()) {
|
|
db.setPassword(cfg.password);
|
|
}
|
|
if (cfg.useSocket) {
|
|
db.setHostName(cfg.socketPath);
|
|
} else {
|
|
db.setHostName(cfg.host);
|
|
}
|
|
if (cfg.port > 0) {
|
|
db.setPort(cfg.port);
|
|
}
|
|
if (!cfg.options.isEmpty()) {
|
|
db.setConnectOptions(cfg.options);
|
|
}
|
|
|
|
if (!db.open()) {
|
|
const std::string err = db.lastError().text().toStdString();
|
|
closeDatabase();
|
|
throw std::runtime_error("PgDal: failed to open database: " + err);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) {
|
|
if (!connected_) {
|
|
return std::nullopt;
|
|
}
|
|
if (auto existing = findNamespace(name)) {
|
|
return existing;
|
|
}
|
|
return createNamespaceWithSecret(name).first;
|
|
}
|
|
|
|
std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlFindNamespace(name);
|
|
}
|
|
|
|
auto it = namespacesByName_.find(name);
|
|
if (it == namespacesByName_.end()) {
|
|
return std::nullopt;
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
|
|
std::optional<NamespaceRow> PgDal::sqlFindNamespace(const std::string& name) const {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"SELECT id::text, name FROM namespaces WHERE name = :name"));
|
|
query.bindValue(QStringLiteral(":name"), QString::fromStdString(name));
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
if (!query.next()) {
|
|
return std::nullopt;
|
|
}
|
|
NamespaceRow row;
|
|
row.id = query.value(0).toString().toStdString();
|
|
row.name = query.value(1).toString().toStdString();
|
|
return row;
|
|
}
|
|
|
|
std::pair<NamespaceRow, std::string> PgDal::createNamespaceWithSecret(const std::string& name) {
|
|
if (!connected_) {
|
|
throw std::runtime_error("PgDal not connected");
|
|
}
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlCreateNamespaceWithSecret(name);
|
|
}
|
|
|
|
// In-memory implementation
|
|
auto it = namespacesByName_.find(name);
|
|
if (it != namespacesByName_.end()) {
|
|
// For in-memory, we don't have secrets, so we can't return one.
|
|
// This path should ideally not be taken in production.
|
|
return {it->second, ""};
|
|
}
|
|
|
|
NamespaceRow row;
|
|
row.id = allocateId(nextNamespaceId_, "ns_");
|
|
row.name = name;
|
|
namespacesByName_[name] = row;
|
|
namespacesById_[row.id] = row;
|
|
|
|
// Secrets are not supported in-memory for now
|
|
return {row, ""};
|
|
}
|
|
|
|
std::optional<AuthSecret> PgDal::findSecretByNamespaceId(const std::string& namespaceId) const {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlFindSecretByNamespaceId(namespaceId);
|
|
}
|
|
// In-memory implementation does not support secrets
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::pair<NamespaceRow, std::string> PgDal::sqlCreateNamespaceWithSecret(const std::string& name) {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
|
|
// 1. Create the namespace
|
|
query.prepare(QStringLiteral(
|
|
"INSERT INTO namespaces (name) VALUES (:name) "
|
|
"ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name "
|
|
"RETURNING id::text, name;"));
|
|
query.bindValue(QStringLiteral(":name"), QString::fromStdString(name));
|
|
if (!query.exec() || !query.next()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
NamespaceRow row;
|
|
row.id = query.value(0).toString().toStdString();
|
|
row.name = query.value(1).toString().toStdString();
|
|
|
|
// 2. Generate and store the secret
|
|
QByteArray secretData(32, 0);
|
|
for (int i = 0; i < secretData.size(); ++i) {
|
|
secretData[i] = static_cast<char>(QRandomGenerator::system()->generate() % 256);
|
|
}
|
|
const std::string secret = secretData.toHex().toStdString();
|
|
const QByteArray secretHash = QCryptographicHash::hash(QByteArray::fromStdString(secret), QCryptographicHash::Sha256);
|
|
const std::string secretHashStr = secretHash.toHex().toStdString();
|
|
|
|
sqlInsertSecret(row.id, secretHashStr);
|
|
|
|
return {row, secret};
|
|
}
|
|
|
|
void PgDal::sqlInsertSecret(const std::string& namespaceId, const std::string& secretHash) {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"INSERT INTO auth_secrets (namespace_id, secret_hash) "
|
|
"VALUES (:namespace_id::uuid, :secret_hash) "
|
|
"ON CONFLICT (namespace_id) DO UPDATE SET secret_hash = EXCLUDED.secret_hash;"));
|
|
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(namespaceId));
|
|
query.bindValue(QStringLiteral(":secret_hash"), QString::fromStdString(secretHash));
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
}
|
|
|
|
std::optional<AuthSecret> PgDal::sqlFindSecretByNamespaceId(const std::string& namespaceId) const {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"SELECT secret_hash FROM auth_secrets WHERE namespace_id = :namespace_id::uuid"));
|
|
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(namespaceId));
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
if (!query.next()) {
|
|
return std::nullopt;
|
|
}
|
|
AuthSecret secret;
|
|
secret.secret_hash = query.value(0).toString().toStdString();
|
|
return secret;
|
|
}
|
|
|
|
|
|
std::string PgDal::upsertItem(const ItemRow& row) {
|
|
if (!connected_) {
|
|
throw std::runtime_error("PgDal not connected");
|
|
}
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlUpsertItem(row).first;
|
|
}
|
|
|
|
ItemRow stored = row;
|
|
if (stored.id.empty()) {
|
|
stored.id = allocateId(nextItemId_, "item_");
|
|
}
|
|
|
|
auto existing = items_.find(stored.id);
|
|
if (existing != items_.end()) {
|
|
stored.revision = existing->second.revision + 1;
|
|
}
|
|
|
|
items_[stored.id] = stored;
|
|
|
|
auto& bucket = itemsByNamespace_[stored.namespace_id];
|
|
if (!idsContains(bucket, stored.id)) {
|
|
bucket.push_back(stored.id);
|
|
}
|
|
|
|
return stored.id;
|
|
}
|
|
|
|
std::pair<std::string, int> PgDal::sqlUpsertItem(const ItemRow& row) {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"INSERT INTO memory_items (id, namespace_id, key, content, metadata, tags, text, created_at, expires_at) "
|
|
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
|
|
" :namespace_id::uuid, :key, :content, :metadata::jsonb, :tags::text[], :text, "
|
|
" COALESCE(:created_at, now()), :expires_at) "
|
|
"ON CONFLICT (id) DO UPDATE SET "
|
|
" key = EXCLUDED.key, content = EXCLUDED.content, metadata = EXCLUDED.metadata, "
|
|
" tags = EXCLUDED.tags, text = EXCLUDED.text, updated_at = now(), "
|
|
" expires_at = EXCLUDED.expires_at, "
|
|
" revision = memory_items.revision + 1 "
|
|
"RETURNING id::text, revision;"));
|
|
|
|
query.bindValue(QStringLiteral(":id"), QString::fromStdString(row.id));
|
|
query.bindValue(QStringLiteral(":namespace_id"), QString::fromStdString(row.namespace_id));
|
|
if (row.key) {
|
|
query.bindValue(QStringLiteral(":key"), QString::fromStdString(*row.key));
|
|
} else {
|
|
query.bindValue(QStringLiteral(":key"), QVariant(QMetaType(QMetaType::QString)));
|
|
}
|
|
query.bindValue(QStringLiteral(":content"), QString::fromStdString(row.content_json));
|
|
query.bindValue(QStringLiteral(":metadata"), QString::fromStdString(row.metadata_json));
|
|
query.bindValue(QStringLiteral(":tags"), QString::fromStdString(toPgArrayLiteral(row.tags)));
|
|
if (row.text) {
|
|
query.bindValue(QStringLiteral(":text"), QString::fromStdString(*row.text));
|
|
} else {
|
|
query.bindValue(QStringLiteral(":text"), QVariant(QMetaType(QMetaType::QString)));
|
|
}
|
|
query.bindValue(QStringLiteral(":created_at"), toSqlTimestamp(row.created_at));
|
|
query.bindValue(QStringLiteral(":expires_at"), toSqlTimestamp(row.expires_at));
|
|
|
|
if (!query.exec() || !query.next()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
|
|
std::pair<std::string, int> result;
|
|
result.first = query.value(0).toString().toStdString();
|
|
result.second = query.value(1).toInt();
|
|
return result;
|
|
}
|
|
|
|
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) {
|
|
if (!connected_) {
|
|
throw std::runtime_error("PgDal not connected");
|
|
}
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlUpsertChunks(chunks);
|
|
}
|
|
|
|
std::vector<std::string> ids;
|
|
ids.reserve(chunks.size());
|
|
|
|
for (const auto& input : chunks) {
|
|
ChunkRow stored = input;
|
|
if (stored.item_id.empty()) {
|
|
continue;
|
|
}
|
|
if (stored.id.empty()) {
|
|
stored.id = allocateId(nextChunkId_, "chunk_");
|
|
}
|
|
|
|
chunks_[stored.id] = stored;
|
|
auto& bucket = chunksByItem_[stored.item_id];
|
|
if (!idsContains(bucket, stored.id)) {
|
|
bucket.push_back(stored.id);
|
|
}
|
|
ids.push_back(stored.id);
|
|
}
|
|
|
|
return ids;
|
|
}
|
|
|
|
std::vector<std::string> PgDal::sqlUpsertChunks(const std::vector<ChunkRow>& chunks) {
|
|
std::vector<std::string> ids;
|
|
ids.reserve(chunks.size());
|
|
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"INSERT INTO memory_chunks (id, item_id, seq, content) "
|
|
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
|
|
" :item_id::uuid, :seq, :content) "
|
|
"ON CONFLICT (id) DO UPDATE SET seq = EXCLUDED.seq, content = EXCLUDED.content "
|
|
"RETURNING id::text;"));
|
|
|
|
for (const auto& chunk : chunks) {
|
|
query.bindValue(QStringLiteral(":id"), QString::fromStdString(chunk.id));
|
|
query.bindValue(QStringLiteral(":item_id"), QString::fromStdString(chunk.item_id));
|
|
query.bindValue(QStringLiteral(":seq"), chunk.ord);
|
|
query.bindValue(QStringLiteral(":content"), QString::fromStdString(chunk.text));
|
|
|
|
if (!query.exec() || !query.next()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
ids.push_back(query.value(0).toString().toStdString());
|
|
query.finish();
|
|
}
|
|
return ids;
|
|
}
|
|
|
|
void PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embeddings) {
|
|
if (!connected_) {
|
|
throw std::runtime_error("PgDal not connected");
|
|
}
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
sqlUpsertEmbeddings(embeddings);
|
|
return;
|
|
}
|
|
|
|
for (const auto& input : embeddings) {
|
|
if (input.chunk_id.empty()) {
|
|
continue;
|
|
}
|
|
EmbeddingRow stored = input;
|
|
if (stored.id.empty()) {
|
|
stored.id = allocateId(nextEmbeddingId_, "emb_");
|
|
}
|
|
embeddings_[stored.chunk_id] = stored;
|
|
}
|
|
}
|
|
|
|
void PgDal::sqlUpsertEmbeddings(const std::vector<EmbeddingRow>& embeddings) {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"INSERT INTO embeddings (id, chunk_id, model, dim, vector, normalized) "
|
|
"VALUES (COALESCE(NULLIF(:id, '')::uuid, gen_random_uuid()), "
|
|
" :chunk_id::uuid, :model, :dim, :vector::vector, FALSE) "
|
|
"ON CONFLICT (chunk_id, model) DO UPDATE SET "
|
|
" dim = EXCLUDED.dim, vector = EXCLUDED.vector, normalized = EXCLUDED.normalized "
|
|
"RETURNING id::text;"));
|
|
|
|
for (const auto& emb : embeddings) {
|
|
query.bindValue(QStringLiteral(":id"), QString::fromStdString(emb.id));
|
|
query.bindValue(QStringLiteral(":chunk_id"), QString::fromStdString(emb.chunk_id));
|
|
query.bindValue(QStringLiteral(":model"), QString::fromStdString(emb.model));
|
|
query.bindValue(QStringLiteral(":dim"), emb.dim);
|
|
query.bindValue(QStringLiteral(":vector"), QString::fromStdString(toPgVectorLiteral(emb.vector)));
|
|
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
query.finish();
|
|
}
|
|
}
|
|
|
|
std::vector<ItemRow> PgDal::searchText(const std::string& namespaceId,
|
|
const std::string& queryText,
|
|
int limit) {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlSearchText(namespaceId, queryText, limit);
|
|
}
|
|
|
|
std::vector<ItemRow> results;
|
|
if (!connected_) return results;
|
|
auto bucketIt = itemsByNamespace_.find(namespaceId);
|
|
if (bucketIt == itemsByNamespace_.end()) return results;
|
|
|
|
const std::string loweredQuery = toLower(queryText);
|
|
|
|
for (const auto& itemId : bucketIt->second) {
|
|
auto itemIt = items_.find(itemId);
|
|
if (itemIt == items_.end()) continue;
|
|
|
|
if (!loweredQuery.empty()) {
|
|
const std::string loweredText = toLower(itemIt->second.text.value_or(std::string()));
|
|
if (loweredText.find(loweredQuery) == std::string::npos) {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
results.push_back(itemIt->second);
|
|
if (static_cast<int>(results.size()) >= limit) break;
|
|
}
|
|
|
|
return results;
|
|
}
|
|
|
|
std::vector<ItemRow> PgDal::sqlSearchText(const std::string& namespaceId,
|
|
const std::string& queryText,
|
|
int limit) const {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
|
|
"FROM memory_items "
|
|
"WHERE namespace_id = :ns::uuid "
|
|
" AND deleted_at IS NULL "
|
|
" AND (:query = '' OR text ILIKE '%' || :query || '%') "
|
|
"ORDER BY updated_at DESC "
|
|
"LIMIT :limit;"));
|
|
query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
|
|
query.bindValue(QStringLiteral(":query"), QString::fromStdString(queryText));
|
|
query.bindValue(QStringLiteral(":limit"), limit);
|
|
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
|
|
std::vector<ItemRow> results;
|
|
while (query.next()) {
|
|
ItemRow row;
|
|
row.id = query.value(0).toString().toStdString();
|
|
row.namespace_id = query.value(1).toString().toStdString();
|
|
if (!query.value(2).isNull()) {
|
|
row.key = query.value(2).toString().toStdString();
|
|
}
|
|
row.content_json = query.value(3).toString().toStdString();
|
|
row.metadata_json = query.value(4).toString().toStdString();
|
|
if (!query.value(5).isNull()) {
|
|
row.text = query.value(5).toString().toStdString();
|
|
}
|
|
row.tags = parsePgTextArray(query.value(6).toString());
|
|
row.revision = query.value(7).toInt();
|
|
row.created_at = fromSqlTimestamp(query.value(8));
|
|
row.expires_at = fromSqlTimestampOptional(query.value(9));
|
|
results.push_back(std::move(row));
|
|
}
|
|
return results;
|
|
}
|
|
|
|
std::vector<std::pair<std::string, float>> PgDal::searchVector(
|
|
const std::string& namespaceId,
|
|
const std::vector<float>& embedding,
|
|
int limit) {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlSearchVector(namespaceId, embedding, limit);
|
|
}
|
|
|
|
std::vector<std::pair<std::string, float>> scores;
|
|
if (!connected_ || embedding.empty()) return scores;
|
|
|
|
auto bucketIt = itemsByNamespace_.find(namespaceId);
|
|
if (bucketIt == itemsByNamespace_.end()) return scores;
|
|
|
|
for (const auto& itemId : bucketIt->second) {
|
|
auto chunkBucketIt = chunksByItem_.find(itemId);
|
|
if (chunkBucketIt == chunksByItem_.end()) continue;
|
|
|
|
float bestScore = -1.0f;
|
|
for (const auto& chunkId : chunkBucketIt->second) {
|
|
auto embIt = embeddings_.find(chunkId);
|
|
if (embIt == embeddings_.end()) continue;
|
|
const auto& storedVec = embIt->second.vector;
|
|
if (storedVec.size() != embedding.size() || storedVec.empty()) continue;
|
|
float dot = std::inner_product(storedVec.begin(), storedVec.end(), embedding.begin(), 0.0f);
|
|
if (dot > bestScore) {
|
|
bestScore = dot;
|
|
}
|
|
}
|
|
|
|
if (bestScore >= 0.0f) {
|
|
scores.emplace_back(itemId, bestScore);
|
|
}
|
|
}
|
|
|
|
std::sort(scores.begin(), scores.end(),
|
|
[](const auto& lhs, const auto& rhs) {
|
|
if (lhs.second == rhs.second) {
|
|
return lhs.first < rhs.first;
|
|
}
|
|
return lhs.second > rhs.second;
|
|
});
|
|
|
|
if (static_cast<int>(scores.size()) > limit) {
|
|
scores.resize(static_cast<std::size_t>(limit));
|
|
}
|
|
|
|
return scores;
|
|
}
|
|
|
|
std::vector<std::pair<std::string, float>> PgDal::sqlSearchVector(
|
|
const std::string& namespaceId,
|
|
const std::vector<float>& embedding,
|
|
int limit) const {
|
|
std::vector<std::pair<std::string, float>> results;
|
|
if (embedding.empty()) {
|
|
return results;
|
|
}
|
|
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"SELECT i.id::text, 1 - (e.vector <=> :vector::vector) AS score "
|
|
"FROM embeddings e "
|
|
"JOIN memory_chunks c ON c.id = e.chunk_id "
|
|
"JOIN memory_items i ON i.id = c.item_id "
|
|
"WHERE i.namespace_id = :ns::uuid "
|
|
" AND i.deleted_at IS NULL "
|
|
"ORDER BY e.vector <-> :vector "
|
|
"LIMIT :limit;"));
|
|
query.bindValue(QStringLiteral(":vector"), QString::fromStdString(toPgVectorLiteral(embedding)));
|
|
query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
|
|
query.bindValue(QStringLiteral(":limit"), limit);
|
|
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
|
|
while (query.next()) {
|
|
std::pair<std::string, float> entry;
|
|
entry.first = query.value(0).toString().toStdString();
|
|
entry.second = static_cast<float>(query.value(1).toDouble());
|
|
results.push_back(std::move(entry));
|
|
}
|
|
return results;
|
|
}
|
|
|
|
std::optional<ItemRow> PgDal::getItemById(const std::string& id) const {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlGetItemById(id);
|
|
}
|
|
|
|
auto it = items_.find(id);
|
|
if (it == items_.end()) {
|
|
return std::nullopt;
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
std::optional<ItemRow> PgDal::sqlGetItemById(const std::string& id) const {
|
|
QSqlDatabase db = database();
|
|
QSqlQuery query(db);
|
|
query.prepare(QStringLiteral(
|
|
"SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
|
|
"FROM memory_items WHERE id = :id::uuid"));
|
|
query.bindValue(QStringLiteral(":id"), QString::fromStdString(id));
|
|
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
if (!query.next()) {
|
|
return std::nullopt;
|
|
}
|
|
ItemRow row;
|
|
row.id = query.value(0).toString().toStdString();
|
|
row.namespace_id = query.value(1).toString().toStdString();
|
|
if (!query.value(2).isNull()) {
|
|
row.key = query.value(2).toString().toStdString();
|
|
}
|
|
row.content_json = query.value(3).toString().toStdString();
|
|
row.metadata_json = query.value(4).toString().toStdString();
|
|
if (!query.value(5).isNull()) {
|
|
row.text = query.value(5).toString().toStdString();
|
|
}
|
|
row.tags = parsePgTextArray(query.value(6).toString());
|
|
row.revision = query.value(7).toInt();
|
|
row.created_at = fromSqlTimestamp(query.value(8));
|
|
row.expires_at = fromSqlTimestampOptional(query.value(9));
|
|
return row;
|
|
}
|
|
|
|
std::vector<ItemRow> PgDal::fetchContext(const std::string& namespaceId,
|
|
const std::optional<std::string>& key,
|
|
const std::vector<std::string>& tags,
|
|
const std::optional<std::string>& sinceIso,
|
|
int limit) {
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlFetchContext(namespaceId, key, tags, sinceIso, limit);
|
|
}
|
|
|
|
std::vector<ItemRow> results;
|
|
auto bucketIt = itemsByNamespace_.find(namespaceId);
|
|
if (bucketIt == itemsByNamespace_.end()) {
|
|
return results;
|
|
}
|
|
|
|
const auto sinceTp = parseIsoTimestamp(sinceIso);
|
|
|
|
for (const auto& itemId : bucketIt->second) {
|
|
auto it = items_.find(itemId);
|
|
if (it == items_.end()) continue;
|
|
|
|
const ItemRow& row = it->second;
|
|
if (key && (!row.key || *row.key != *key)) {
|
|
continue;
|
|
}
|
|
if (!tags.empty() && !tagsMatch(row.tags, tags)) {
|
|
continue;
|
|
}
|
|
if (sinceTp && row.created_at < *sinceTp) {
|
|
continue;
|
|
}
|
|
results.push_back(row);
|
|
}
|
|
|
|
std::sort(results.begin(), results.end(), [](const ItemRow& a, const ItemRow& b) {
|
|
return a.created_at > b.created_at;
|
|
});
|
|
|
|
if (static_cast<int>(results.size()) > limit) {
|
|
results.resize(static_cast<std::size_t>(limit));
|
|
}
|
|
return results;
|
|
}
|
|
|
|
std::vector<ItemRow> PgDal::sqlFetchContext(const std::string& namespaceId,
|
|
const std::optional<std::string>& key,
|
|
const std::vector<std::string>& tags,
|
|
const std::optional<std::string>& sinceIso,
|
|
int limit) const {
|
|
QSqlDatabase db = database();
|
|
QString queryStr = QStringLiteral(
|
|
"SELECT id::text, namespace_id::text, key, content, metadata::text, text, tags::text[], revision, created_at, expires_at "
|
|
"FROM memory_items WHERE namespace_id = :ns::uuid AND deleted_at IS NULL");
|
|
|
|
if (key && !key->empty()) {
|
|
queryStr += QStringLiteral(" AND key = :key");
|
|
}
|
|
if (!tags.empty()) {
|
|
queryStr += QStringLiteral(" AND tags @> :tags::text[]");
|
|
}
|
|
if (sinceIso && !sinceIso->empty()) {
|
|
queryStr += QStringLiteral(" AND created_at >= :since");
|
|
}
|
|
queryStr += QStringLiteral(" ORDER BY created_at DESC LIMIT :limit");
|
|
|
|
QSqlQuery query(db);
|
|
query.prepare(queryStr);
|
|
query.bindValue(QStringLiteral(":ns"), QString::fromStdString(namespaceId));
|
|
query.bindValue(QStringLiteral(":limit"), limit);
|
|
|
|
if (key && !key->empty()) {
|
|
query.bindValue(QStringLiteral(":key"), QString::fromStdString(*key));
|
|
}
|
|
if (!tags.empty()) {
|
|
query.bindValue(QStringLiteral(":tags"), QString::fromStdString(toPgArrayLiteral(tags)));
|
|
}
|
|
if (sinceIso && !sinceIso->empty()) {
|
|
const auto sinceTp = parseIsoTimestamp(sinceIso);
|
|
query.bindValue(QStringLiteral(":since"), toSqlTimestamp(sinceTp));
|
|
}
|
|
|
|
if (!query.exec()) {
|
|
throw std::runtime_error(query.lastError().text().toStdString());
|
|
}
|
|
|
|
std::vector<ItemRow> rows;
|
|
while (query.next()) {
|
|
ItemRow row;
|
|
row.id = query.value(0).toString().toStdString();
|
|
row.namespace_id = query.value(1).toString().toStdString();
|
|
if (!query.value(2).isNull()) {
|
|
row.key = query.value(2).toString().toStdString();
|
|
}
|
|
row.content_json = query.value(3).toString().toStdString();
|
|
row.metadata_json = query.value(4).toString().toStdString();
|
|
if (!query.value(5).isNull()) {
|
|
row.text = query.value(5).toString().toStdString();
|
|
}
|
|
row.tags = parsePgTextArray(query.value(6).toString());
|
|
row.revision = query.value(7).toInt();
|
|
row.created_at = fromSqlTimestamp(query.value(8));
|
|
row.expires_at = fromSqlTimestampOptional(query.value(9));
|
|
rows.push_back(std::move(row));
|
|
}
|
|
return rows;
|
|
}
|
|
|
|
std::pair<std::string, int> PgDal::upsertItem(
|
|
const std::string& namespace_id,
|
|
const std::optional<std::string>& key,
|
|
const std::string& content,
|
|
const std::string& metadata_json,
|
|
const std::vector<std::string>& tags) {
|
|
ItemRow row;
|
|
row.namespace_id = namespace_id;
|
|
row.key = key;
|
|
row.content_json = content;
|
|
row.metadata_json = metadata_json.empty() ? "{}" : metadata_json;
|
|
if (!content.empty()) {
|
|
row.text = content;
|
|
}
|
|
row.tags = tags;
|
|
row.created_at = std::chrono::system_clock::now();
|
|
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlUpsertItem(row);
|
|
}
|
|
|
|
const std::string id = upsertItem(row);
|
|
const auto stored = items_.find(id);
|
|
const int revision = stored != items_.end() ? stored->second.revision : 1;
|
|
return {id, revision};
|
|
}
|
|
|
|
std::string PgDal::insertChunk(const std::string& item_id,
|
|
int seq,
|
|
const std::string& content) {
|
|
ChunkRow row;
|
|
row.item_id = item_id;
|
|
row.ord = seq;
|
|
row.text = content;
|
|
auto ids = upsertChunks(std::vector<ChunkRow>{row});
|
|
return ids.empty() ? std::string() : ids.front();
|
|
}
|
|
|
|
void PgDal::insertEmbedding(const Embedding& embedding) {
|
|
EmbeddingRow row;
|
|
row.chunk_id = embedding.chunk_id;
|
|
row.model = embedding.model;
|
|
row.dim = embedding.dim;
|
|
row.vector = embedding.vector;
|
|
upsertEmbeddings(std::vector<EmbeddingRow>{row});
|
|
}
|
|
|
|
std::vector<std::string> PgDal::hybridSearch(const std::vector<float>& query_vec,
|
|
const std::string& model,
|
|
const std::string& namespace_id,
|
|
const std::string& query_text,
|
|
int k) {
|
|
(void)model;
|
|
|
|
if (!useInMemory_ && hasDatabase()) {
|
|
return sqlHybridSearch(query_vec, model, namespace_id, query_text, k);
|
|
}
|
|
|
|
std::vector<std::string> results;
|
|
auto textMatches = searchText(namespace_id, query_text, k);
|
|
for (const auto& item : textMatches) {
|
|
results.push_back(item.id);
|
|
if (static_cast<int>(results.size()) >= k) {
|
|
return results;
|
|
}
|
|
}
|
|
|
|
auto vectorMatches = searchVector(namespace_id, query_vec, k);
|
|
for (const auto& pair : vectorMatches) {
|
|
if (!idsContains(results, pair.first)) {
|
|
results.push_back(pair.first);
|
|
}
|
|
if (static_cast<int>(results.size()) >= k) break;
|
|
}
|
|
|
|
return results;
|
|
}
|
|
|
|
std::vector<std::string> PgDal::sqlHybridSearch(const std::vector<float>& query_vec,
|
|
const std::string& model,
|
|
const std::string& namespace_id,
|
|
const std::string& query_text,
|
|
int k) {
|
|
(void)model;
|
|
std::unordered_set<std::string> seen;
|
|
std::vector<std::string> results;
|
|
|
|
auto textMatches = sqlSearchText(namespace_id, query_text, k);
|
|
for (const auto& item : textMatches) {
|
|
results.push_back(item.id);
|
|
seen.insert(item.id);
|
|
if (static_cast<int>(results.size()) >= k) {
|
|
return results;
|
|
}
|
|
}
|
|
|
|
if (!query_vec.empty()) {
|
|
auto vectorMatches = sqlSearchVector(namespace_id, query_vec, k);
|
|
for (const auto& pair : vectorMatches) {
|
|
if (seen.count(pair.first)) continue;
|
|
results.push_back(pair.first);
|
|
seen.insert(pair.first);
|
|
if (static_cast<int>(results.size()) >= k) break;
|
|
}
|
|
}
|
|
|
|
return results;
|
|
}
|
|
|
|
// Builds the next id for the in-memory backend as `prefix` + current counter
// value, then advances the counter.
std::string PgDal::allocateId(std::size_t& counter, const std::string& prefix) {
    const std::size_t assigned = counter;
    ++counter;
    return prefix + std::to_string(assigned);
}
|
|
|
|
// Returns a copy of `value` with every byte passed through std::tolower.
// Bytes are widened via unsigned char first, as std::tolower requires.
std::string PgDal::toLower(const std::string& value) {
    std::string lowered;
    lowered.reserve(value.size());
    for (const unsigned char byte : value) {
        lowered.push_back(static_cast<char>(std::tolower(byte)));
    }
    return lowered;
}
|
|
|
|
// Escapes a string for use inside a double-quoted PostgreSQL array element:
// every '"' and '\' is preceded by a backslash.
std::string PgDal::escapePgArrayElement(const std::string& value) {
    std::string escaped;
    escaped.reserve(value.size());
    for (const char ch : value) {
        const bool needsBackslash = (ch == '\\' || ch == '"');
        if (needsBackslash) {
            escaped += '\\';
        }
        escaped += ch;
    }
    return escaped;
}
|
|
|
|
std::string PgDal::toPgArrayLiteral(const std::vector<std::string>& values) {
|
|
if (values.empty()) {
|
|
return "{}";
|
|
}
|
|
std::string out = "{";
|
|
for (std::size_t i = 0; i < values.size(); ++i) {
|
|
if (i) out += ',';
|
|
out += '"';
|
|
out += escapePgArrayElement(values[i]);
|
|
out += '"';
|
|
}
|
|
out += "}";
|
|
return out;
|
|
}
|
|
|
|
std::string PgDal::toPgVectorLiteral(const std::vector<float>& values) {
|
|
if (values.empty()) {
|
|
return "[]";
|
|
}
|
|
std::string out = "[";
|
|
for (std::size_t i = 0; i < values.size(); ++i) {
|
|
if (i) out += ',';
|
|
out += std::to_string(values[i]);
|
|
}
|
|
out += "]";
|
|
return out;
|
|
}
|
|
|
|
std::vector<std::string> PgDal::parsePgTextArray(const QString& value) {
|
|
std::vector<std::string> tags;
|
|
QString trimmed = value.trimmed();
|
|
if (!trimmed.startsWith(QLatin1Char('{')) || !trimmed.endsWith(QLatin1Char('}'))) {
|
|
return tags;
|
|
}
|
|
trimmed = trimmed.mid(1, trimmed.size() - 2);
|
|
QString current;
|
|
bool inQuotes = false;
|
|
bool escape = false;
|
|
for (QChar ch : trimmed) {
|
|
if (escape) {
|
|
current.append(ch);
|
|
escape = false;
|
|
continue;
|
|
}
|
|
if (ch == QLatin1Char('\\')) {
|
|
escape = true;
|
|
continue;
|
|
}
|
|
if (ch == QLatin1Char('"')) {
|
|
inQuotes = !inQuotes;
|
|
continue;
|
|
}
|
|
if (!inQuotes && ch == QLatin1Char(',')) {
|
|
tags.push_back(current.toStdString());
|
|
current.clear();
|
|
continue;
|
|
}
|
|
current.append(ch);
|
|
}
|
|
if (!current.isEmpty()) {
|
|
tags.push_back(current.toStdString());
|
|
}
|
|
return tags;
|
|
}
|
|
|
|
} // namespace kom
|