Add PgDal Implementation

This commit is contained in:
Χγφτ Kompanion 2025-10-15 01:01:26 +13:00
parent 93400a2d21
commit 2210e3a260
5 changed files with 413 additions and 239 deletions

View File

@ -1,32 +1,34 @@
#pragma once #pragma once
#include <optional> #include "Models.hpp"
#include <string> #include <string>
#include <vector> #include <vector>
#include "Models.hpp" #include <utility>
#include <optional>
namespace kom {
class IDatabase { class IDatabase {
public: public:
virtual ~IDatabase() = default; virtual ~IDatabase() = default;
// Upsert item by (namespace,key); returns {item_id, new_revision}.
virtual std::pair<std::string,int> upsertItem(
const std::string& namespace_id,
const std::optional<std::string>& key,
const std::string& content,
const std::string& metadata_json,
const std::vector<std::string>& tags) = 0;
/** // Insert a chunk; returns chunk_id.
* Initialise the connection using a libpq/pqxx compatible DSN. virtual std::string insertChunk(const std::string& item_id, int seq, const std::string& content) = 0;
* The stub implementation keeps data in-process but honours the API.
*/
virtual bool connect(const std::string& dsn) = 0;
virtual void close() = 0;
// Transactions // Insert an embedding for a chunk.
virtual bool begin() = 0; virtual void insertEmbedding(const Embedding& e) = 0;
virtual bool commit() = 0;
virtual void rollback() = 0;
// Memory ops (skeleton) // Hybrid search. Returns chunk_ids ordered by relevance.
virtual std::optional<NamespaceRow> ensureNamespace(const std::string& name) = 0; virtual std::vector<std::string> hybridSearch(
virtual std::optional<NamespaceRow> findNamespace(const std::string& name) const = 0; const std::vector<float>& query_vec,
virtual std::string upsertItem(const ItemRow& item) = 0; const std::string& model,
virtual std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) = 0; const std::string& namespace_id,
virtual std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) = 0; const std::string& query_text,
virtual std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) = 0; int k) = 0;
virtual std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) = 0;
virtual std::optional<ItemRow> getItemById(const std::string& item_id) = 0;
}; };
}

View File

@ -2,36 +2,30 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <optional> #include <optional>
#include <unordered_map> #include <cstdint>
struct NamespaceRow { std::string id; std::string name; }; namespace kom {
struct ThreadRow { std::string id; std::string namespace_id; std::string external_id; }; struct MemoryItem {
struct UserRow { std::string id; std::string external_id; };
struct ItemRow {
std::string id; std::string id;
std::string namespace_id; std::string namespace_id;
std::optional<std::string> thread_id;
std::optional<std::string> user_id;
std::optional<std::string> key; std::optional<std::string> key;
std::string content_json; std::string content;
std::optional<std::string> text; std::string metadata_json;
std::vector<std::string> tags; std::vector<std::string> tags;
std::unordered_map<std::string, std::string> metadata; int revision = 1;
int revision{1};
}; };
struct ChunkRow { struct MemoryChunk {
std::string id; std::string id;
std::string item_id; std::string item_id;
int ord{0}; int seq = 0;
std::string text; std::string content;
}; };
struct EmbeddingRow { struct Embedding {
std::string id;
std::string chunk_id; std::string chunk_id;
std::string model; std::string model;
int dim{0}; int dim = 1536;
std::vector<float> vector; std::vector<float> vector;
}; };
}

View File

@ -2,190 +2,303 @@
#include <algorithm> #include <algorithm>
#include <cctype> #include <cctype>
#include <cstddef>
#include <mutex>
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
bool PgDal::connect(const std::string& dsn) { namespace kom {
std::lock_guard<std::mutex> lock(guard_);
dsn_ = dsn; namespace {
connected_ = true;
return true; bool idsContains(const std::vector<std::string>& ids, const std::string& value) {
return std::find(ids.begin(), ids.end(), value) != ids.end();
} }
void PgDal::close() { } // namespace
std::lock_guard<std::mutex> lock(guard_);
connected_ = false; PgDal::PgDal() = default;
inTransaction_ = false; PgDal::~PgDal() = default;
bool PgDal::connect(const std::string& dsn) {
dsn_ = dsn;
connected_ = true;
useInMemory_ = true;
namespacesByName_.clear();
namespacesById_.clear();
items_.clear();
itemsByNamespace_.clear();
chunks_.clear();
chunksByItem_.clear();
embeddings_.clear();
nextNamespaceId_ = 1;
nextItemId_ = 1;
nextChunkId_ = 1;
nextEmbeddingId_ = 1;
return connected_;
} }
bool PgDal::begin() { bool PgDal::begin() {
std::lock_guard<std::mutex> lock(guard_); return connected_ && !useInMemory_;
if (!connected_ || inTransaction_) return false;
inTransaction_ = true;
return true;
} }
bool PgDal::commit() { void PgDal::commit() {}
std::lock_guard<std::mutex> lock(guard_);
if (!connected_ || !inTransaction_) return false;
inTransaction_ = false;
return true;
}
void PgDal::rollback() { void PgDal::rollback() {}
std::lock_guard<std::mutex> lock(guard_);
inTransaction_ = false;
}
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) { std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) {
std::lock_guard<std::mutex> lock(guard_); if (!connected_) return std::nullopt;
if (name.empty()) return std::nullopt;
auto it = namespacesByName_.find(name); auto it = namespacesByName_.find(name);
if (it != namespacesByName_.end()) return it->second; if (it != namespacesByName_.end()) {
return it->second;
}
NamespaceRow row; NamespaceRow row;
row.id = makeSyntheticId("ns"); row.id = allocateId(nextNamespaceId_, "ns_");
row.name = name; row.name = name;
namespacesByName_[name] = row; namespacesByName_[name] = row;
namespacesById_[row.id] = row; namespacesById_[row.id] = row;
return row; return row;
} }
std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const { std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const {
std::lock_guard<std::mutex> lock(guard_);
auto it = namespacesByName_.find(name); auto it = namespacesByName_.find(name);
if (it == namespacesByName_.end()) return std::nullopt; if (it == namespacesByName_.end()) {
return std::nullopt;
}
return it->second; return it->second;
} }
std::string PgDal::upsertItem(const ItemRow& item) { std::string PgDal::upsertItem(const ItemRow& row) {
std::lock_guard<std::mutex> lock(guard_); if (!connected_) {
if (item.namespace_id.empty()) throw std::runtime_error("item missing namespace_id"); throw std::runtime_error("PgDal not connected");
ItemRow stored = item;
if (stored.id.empty()) {
stored.id = makeSyntheticId("item");
} }
itemsById_[stored.id] = stored;
ItemRow stored = row;
if (stored.id.empty()) {
stored.id = allocateId(nextItemId_, "item_");
}
auto existing = items_.find(stored.id);
if (existing != items_.end()) {
stored.revision = existing->second.revision + 1;
}
items_[stored.id] = stored;
auto& bucket = itemsByNamespace_[stored.namespace_id]; auto& bucket = itemsByNamespace_[stored.namespace_id];
if (std::find(bucket.begin(), bucket.end(), stored.id) == bucket.end()) { if (!idsContains(bucket, stored.id)) {
bucket.push_back(stored.id); bucket.push_back(stored.id);
} }
return stored.id; return stored.id;
} }
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) { std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) {
std::lock_guard<std::mutex> lock(guard_); if (!connected_) {
throw std::runtime_error("PgDal not connected");
}
std::vector<std::string> ids; std::vector<std::string> ids;
ids.reserve(chunks.size()); ids.reserve(chunks.size());
for (auto chunk : chunks) {
if (chunk.id.empty()) { for (const auto& input : chunks) {
chunk.id = makeSyntheticId("chunk"); ChunkRow stored = input;
if (stored.item_id.empty()) {
continue;
} }
chunksById_[chunk.id] = chunk; if (stored.id.empty()) {
ids.push_back(chunk.id); stored.id = allocateId(nextChunkId_, "chunk_");
} }
chunks_[stored.id] = stored;
auto& bucket = chunksByItem_[stored.item_id];
if (!idsContains(bucket, stored.id)) {
bucket.push_back(stored.id);
}
ids.push_back(stored.id);
}
return ids; return ids;
} }
std::vector<std::string> PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embs) { void PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embeddings) {
std::lock_guard<std::mutex> lock(guard_); if (!connected_) {
std::vector<std::string> ids; throw std::runtime_error("PgDal not connected");
ids.reserve(embs.size());
for (auto emb : embs) {
if (emb.id.empty()) {
emb.id = makeSyntheticId("emb");
} }
embeddingsById_[emb.id] = emb;
ids.push_back(emb.id); for (const auto& input : embeddings) {
if (input.chunk_id.empty()) {
continue;
}
EmbeddingRow stored = input;
if (stored.id.empty()) {
stored.id = allocateId(nextEmbeddingId_, "emb_");
}
embeddings_[stored.chunk_id] = stored;
} }
return ids;
} }
std::vector<ItemRow> PgDal::searchText(const std::string& namespace_id, const std::string& query, int k) { std::vector<ItemRow> PgDal::searchText(const std::string& namespaceId,
std::lock_guard<std::mutex> lock(guard_); const std::string& query,
std::vector<ItemRow> result; int limit) {
auto nsIt = itemsByNamespace_.find(namespace_id); std::vector<ItemRow> results;
if (nsIt == itemsByNamespace_.end()) return result; if (!connected_) return results;
auto bucketIt = itemsByNamespace_.find(namespaceId);
if (bucketIt == itemsByNamespace_.end()) return results;
const std::string needle = toLowerCopy(query); const std::string loweredQuery = toLower(query);
for (const auto& itemId : nsIt->second) {
auto itemIt = itemsById_.find(itemId);
if (itemIt == itemsById_.end()) continue;
if (needle.empty()) { for (const auto& itemId : bucketIt->second) {
result.push_back(itemIt->second); auto itemIt = items_.find(itemId);
} else { if (itemIt == items_.end()) continue;
std::string hay = toLowerCopy(itemIt->second.text.value_or(""));
if (hay.find(needle) != std::string::npos) { if (!loweredQuery.empty()) {
result.push_back(itemIt->second); const std::string loweredText = toLower(itemIt->second.text.value_or(std::string()));
if (loweredText.find(loweredQuery) == std::string::npos) {
continue;
} }
} }
if (k > 0 && static_cast<int>(result.size()) >= k) break; results.push_back(itemIt->second);
if (static_cast<int>(results.size()) >= limit) break;
} }
return result;
return results;
} }
std::vector<std::pair<std::string, float>> PgDal::searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) { std::vector<std::pair<std::string, float>> PgDal::searchVector(
std::lock_guard<std::mutex> lock(guard_); const std::string& namespaceId,
std::unordered_map<std::string, float> bestScoreByItem; const std::vector<float>& embedding,
if (embedding.empty()) return {}; int limit) {
std::vector<std::pair<std::string, float>> scores;
if (!connected_ || embedding.empty()) return scores;
for (const auto& [embeddingId, row] : embeddingsById_) { auto bucketIt = itemsByNamespace_.find(namespaceId);
auto chunkIt = chunksById_.find(row.chunk_id); if (bucketIt == itemsByNamespace_.end()) return scores;
if (chunkIt == chunksById_.end()) continue;
auto itemIt = itemsById_.find(chunkIt->second.item_id);
if (itemIt == itemsById_.end()) continue;
if (itemIt->second.namespace_id != namespace_id) continue;
const auto& storedVec = row.vector; for (const auto& itemId : bucketIt->second) {
if (storedVec.empty()) continue; auto chunkBucketIt = chunksByItem_.find(itemId);
std::size_t dim = std::min(storedVec.size(), embedding.size()); if (chunkBucketIt == chunksByItem_.end()) continue;
if (dim == 0) continue;
auto span = static_cast<std::ptrdiff_t>(dim); float bestScore = -1.0f;
float dot = std::inner_product(storedVec.begin(), storedVec.begin() + span, for (const auto& chunkId : chunkBucketIt->second) {
embedding.begin(), 0.0f); auto embIt = embeddings_.find(chunkId);
float score = dot / static_cast<float>(dim); if (embIt == embeddings_.end()) continue;
const auto& storedVec = embIt->second.vector;
auto [it, inserted] = bestScoreByItem.emplace(itemIt->first, score); if (storedVec.size() != embedding.size() || storedVec.empty()) continue;
if (!inserted && score > it->second) { float dot = std::inner_product(storedVec.begin(), storedVec.end(), embedding.begin(), 0.0f);
it->second = score; if (dot > bestScore) {
bestScore = dot;
} }
} }
std::vector<std::pair<std::string, float>> scored; if (bestScore >= 0.0f) {
scored.reserve(bestScoreByItem.size()); scores.emplace_back(itemId, bestScore);
for (const auto& kv : bestScoreByItem) scored.push_back(kv); }
}
std::sort(scored.begin(), scored.end(), [](const auto& lhs, const auto& rhs) { std::sort(scores.begin(), scores.end(),
[](const auto& lhs, const auto& rhs) {
if (lhs.second == rhs.second) {
return lhs.first < rhs.first;
}
return lhs.second > rhs.second; return lhs.second > rhs.second;
}); });
if (k > 0 && static_cast<std::size_t>(k) < scored.size()) {
scored.resize(static_cast<std::size_t>(k)); if (static_cast<int>(scores.size()) > limit) {
scores.resize(static_cast<std::size_t>(limit));
} }
return scored;
return scores;
} }
std::optional<ItemRow> PgDal::getItemById(const std::string& item_id) { std::optional<ItemRow> PgDal::getItemById(const std::string& id) const {
std::lock_guard<std::mutex> lock(guard_); auto it = items_.find(id);
auto it = itemsById_.find(item_id); if (it == items_.end()) {
if (it == itemsById_.end()) return std::nullopt; return std::nullopt;
}
return it->second; return it->second;
} }
std::string PgDal::makeSyntheticId(const std::string& prefix) { std::pair<std::string, int> PgDal::upsertItem(
return prefix + "-" + std::to_string(idCounter_++); const std::string& namespace_id,
const std::optional<std::string>& key,
const std::string& content,
const std::string& metadata_json,
const std::vector<std::string>& tags) {
ItemRow row;
row.namespace_id = namespace_id;
row.key = key;
row.content_json = metadata_json.empty() ? content : metadata_json;
if (!content.empty()) {
row.text = content;
}
row.tags = tags;
const std::string id = upsertItem(row);
const auto stored = items_.find(id);
const int revision = stored != items_.end() ? stored->second.revision : 1;
return {id, revision};
} }
std::string PgDal::toLowerCopy(const std::string& in) { std::string PgDal::insertChunk(const std::string& item_id,
std::string out = in; int seq,
std::transform(out.begin(), out.end(), out.begin(), [](unsigned char c) { const std::string& content) {
return static_cast<char>(std::tolower(c)); ChunkRow row;
}); row.item_id = item_id;
return out; row.ord = seq;
row.text = content;
auto ids = upsertChunks(std::vector<ChunkRow>{row});
return ids.empty() ? std::string() : ids.front();
} }
void PgDal::insertEmbedding(const Embedding& embedding) {
EmbeddingRow row;
row.chunk_id = embedding.chunk_id;
row.model = embedding.model;
row.dim = embedding.dim;
row.vector = embedding.vector;
upsertEmbeddings(std::vector<EmbeddingRow>{row});
}
/// Combined text + vector lookup.
///
/// Ids of searchText() matches are emitted first, in the order that call
/// returns them. Remaining slots are filled from searchVector() results,
/// skipping ids that are already in the output.
///
/// @param query_vec     Query embedding forwarded to searchVector().
/// @param model         Currently ignored.
/// @param namespace_id  Namespace id forwarded to both searches.
/// @param query_text    Text forwarded to searchText().
/// @param k             Maximum number of ids to return; k <= 0 yields an
///                      empty result.
/// @return At most k ids, text matches before vector matches. The ids are
///         taken from ItemRow::id and from the first member of each
///         searchVector() pair.
std::vector<std::string> PgDal::hybridSearch(const std::vector<float>& query_vec,
                                             const std::string& model,
                                             const std::string& namespace_id,
                                             const std::string& query_text,
                                             int k) {
    (void)model;
    std::vector<std::string> results;

    // A non-positive k must not reach the helpers: searchText() appends a row
    // before it checks its limit, and searchVector() casts its limit to
    // size_t for resize(), which turns a negative value into a huge length.
    if (k <= 0) {
        return results;
    }
    const auto wanted = static_cast<std::size_t>(k);

    const auto textMatches = searchText(namespace_id, query_text, k);
    for (const auto& item : textMatches) {
        results.push_back(item.id);
        if (results.size() >= wanted) {
            return results;
        }
    }

    const auto vectorMatches = searchVector(namespace_id, query_vec, k);
    for (const auto& match : vectorMatches) {
        if (!idsContains(results, match.first)) {
            results.push_back(match.first);
        }
        if (results.size() >= wanted) {
            break;
        }
    }
    return results;
}
/// Produces the next synthetic id for one id sequence.
///
/// @param counter  Sequence counter; its current value is used for the id and
///                 the counter is then advanced by one.
/// @param prefix   Text placed in front of the decimal counter value.
/// @return prefix followed by the pre-increment counter value.
std::string PgDal::allocateId(std::size_t& counter, const std::string& prefix) {
    const std::size_t issued = counter;
    ++counter;
    std::string id = prefix;
    id += std::to_string(issued);
    return id;
}
/// Returns a copy of `value` with every char passed through std::tolower.
///
/// Each char is converted to unsigned char before the call, as std::tolower
/// requires a value representable as unsigned char.
///
/// @param value  Input text; not modified.
/// @return The lower-cased copy, same length as the input.
std::string PgDal::toLower(const std::string& value) {
    std::string lowered;
    lowered.reserve(value.size());
    for (const char ch : value) {
        const auto asUnsigned = static_cast<unsigned char>(ch);
        lowered.push_back(static_cast<char>(std::tolower(asUnsigned)));
    }
    return lowered;
}
} // namespace kom

View File

@ -1,42 +1,107 @@
#pragma once #pragma once
#include "IDatabase.hpp" #include "IDatabase.hpp"
#include <mutex>
#include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
class PgDal : public IDatabase { namespace kom {
public:
bool connect(const std::string& dsn) override;
void close() override;
bool begin() override;
bool commit() override;
void rollback() override;
std::optional<NamespaceRow> ensureNamespace(const std::string& name) override; struct NamespaceRow {
std::optional<NamespaceRow> findNamespace(const std::string& name) const override; std::string id;
std::string upsertItem(const ItemRow& item) override; std::string name;
std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) override; };
std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) override;
std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) override; struct ItemRow {
std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) override; std::string id;
std::optional<ItemRow> getItemById(const std::string& item_id) override; std::string namespace_id;
std::optional<std::string> key;
std::string content_json;
std::optional<std::string> text;
std::vector<std::string> tags;
int revision = 1;
};
/// One chunk of an item's text as held by PgDal.
struct ChunkRow {
    std::string id{};       ///< Chunk id; empty until one is assigned.
    std::string item_id{};  ///< Id of the item this chunk belongs to.
    int ord{0};             ///< Position of the chunk; defaults to 0.
    std::string text{};     ///< Chunk text.
};
/// One stored embedding vector, tied to a chunk, as held by PgDal.
struct EmbeddingRow {
    std::string id{};             ///< Embedding id; empty until one is assigned.
    std::string chunk_id{};       ///< Id of the chunk this vector describes.
    std::string model{};          ///< Name of the model that produced the vector.
    int dim{0};                   ///< Declared dimension; defaults to 0.
    std::vector<float> vector{};  ///< The embedding values.
};
class PgDal final : public IDatabase {
public:
PgDal();
~PgDal();
bool connect(const std::string& dsn);
bool begin();
void commit();
void rollback();
std::optional<NamespaceRow> ensureNamespace(const std::string& name);
std::optional<NamespaceRow> findNamespace(const std::string& name) const;
std::string upsertItem(const ItemRow& row);
std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks);
void upsertEmbeddings(const std::vector<EmbeddingRow>& embeddings);
std::vector<ItemRow> searchText(const std::string& namespaceId,
const std::string& query,
int limit);
std::vector<std::pair<std::string, float>> searchVector(
const std::string& namespaceId,
const std::vector<float>& embedding,
int limit);
std::optional<ItemRow> getItemById(const std::string& id) const;
// IDatabase overrides (stubbed for now to operate on the in-memory backing store).
std::pair<std::string, int> upsertItem(
const std::string& namespace_id,
const std::optional<std::string>& key,
const std::string& content,
const std::string& metadata_json,
const std::vector<std::string>& tags) override;
std::string insertChunk(const std::string& item_id,
int seq,
const std::string& content) override;
void insertEmbedding(const Embedding& embedding) override;
std::vector<std::string> hybridSearch(const std::vector<float>& query_vec,
const std::string& model,
const std::string& namespace_id,
const std::string& query_text,
int k) override;
private: private:
std::string makeSyntheticId(const std::string& prefix); std::string allocateId(std::size_t& counter, const std::string& prefix);
static std::string toLowerCopy(const std::string& in); static std::string toLower(const std::string& value);
bool connected_{false}; bool connected_ = false;
bool inTransaction_{false}; bool useInMemory_ = true;
std::string dsn_; std::string dsn_;
std::size_t idCounter_{1};
std::size_t nextNamespaceId_ = 1;
std::size_t nextItemId_ = 1;
std::size_t nextChunkId_ = 1;
std::size_t nextEmbeddingId_ = 1;
std::unordered_map<std::string, NamespaceRow> namespacesByName_; std::unordered_map<std::string, NamespaceRow> namespacesByName_;
std::unordered_map<std::string, NamespaceRow> namespacesById_; std::unordered_map<std::string, NamespaceRow> namespacesById_;
std::unordered_map<std::string, ItemRow> itemsById_; std::unordered_map<std::string, ItemRow> items_;
std::unordered_map<std::string, std::vector<std::string>> itemsByNamespace_; std::unordered_map<std::string, std::vector<std::string>> itemsByNamespace_;
std::unordered_map<std::string, ChunkRow> chunksById_; std::unordered_map<std::string, ChunkRow> chunks_;
std::unordered_map<std::string, EmbeddingRow> embeddingsById_; std::unordered_map<std::string, std::vector<std::string>> chunksByItem_;
std::unordered_map<std::string, EmbeddingRow> embeddings_;
mutable std::mutex guard_;
}; };
} // namespace kom

View File

@ -15,8 +15,8 @@
namespace Handlers { namespace Handlers {
namespace detail { namespace detail {
inline PgDal& database() { inline kom::PgDal& database() {
static PgDal instance; static kom::PgDal instance;
static bool connected = [] { static bool connected = [] {
const char* env = std::getenv("PG_DSN"); const char* env = std::getenv("PG_DSN");
const std::string dsn = (env && *env) ? std::string(env) : std::string(); const std::string dsn = (env && *env) ? std::string(env) : std::string();
@ -321,14 +321,14 @@ inline std::string upsert_memory(const std::string& reqJson) {
return detail::error_response("bad_request", "items array must contain at least one entry"); return detail::error_response("bad_request", "items array must contain at least one entry");
} }
PgDal& dal = detail::database(); kom::PgDal& dal = detail::database();
const bool hasTx = dal.begin(); const bool hasTx = dal.begin();
std::vector<std::string> ids; std::vector<std::string> ids;
ids.reserve(items.size()); ids.reserve(items.size());
try { try {
for (auto& parsed : items) { for (auto& parsed : items) {
ItemRow row; kom::ItemRow row;
row.id = parsed.id; row.id = parsed.id;
row.namespace_id = nsRow->id; row.namespace_id = nsRow->id;
row.content_json = parsed.rawJson; row.content_json = parsed.rawJson;
@ -340,18 +340,18 @@ inline std::string upsert_memory(const std::string& reqJson) {
ids.push_back(itemId); ids.push_back(itemId);
if (!parsed.embedding.empty()) { if (!parsed.embedding.empty()) {
ChunkRow chunk; kom::ChunkRow chunk;
chunk.item_id = itemId; chunk.item_id = itemId;
chunk.ord = 0; chunk.ord = 0;
chunk.text = parsed.text; chunk.text = parsed.text;
auto chunkIds = dal.upsertChunks({chunk}); auto chunkIds = dal.upsertChunks(std::vector<kom::ChunkRow>{chunk});
EmbeddingRow emb; kom::EmbeddingRow emb;
emb.chunk_id = chunkIds.front(); emb.chunk_id = chunkIds.front();
emb.model = "stub-model"; emb.model = "stub-model";
emb.dim = static_cast<int>(parsed.embedding.size()); emb.dim = static_cast<int>(parsed.embedding.size());
emb.vector = parsed.embedding; emb.vector = parsed.embedding;
dal.upsertEmbeddings({emb}); dal.upsertEmbeddings(std::vector<kom::EmbeddingRow>{emb});
} }
} }
if (hasTx) dal.commit(); if (hasTx) dal.commit();
@ -397,7 +397,7 @@ inline std::string search_memory(const std::string& reqJson) {
} }
} }
PgDal& dal = detail::database(); kom::PgDal& dal = detail::database();
std::unordered_set<std::string> seen; std::unordered_set<std::string> seen;
std::vector<detail::SearchMatch> matches; std::vector<detail::SearchMatch> matches;