Add PgDal Implementation
This commit is contained in:
parent
ba9c4c0f72
commit
93400a2d21
|
|
@ -8,13 +8,25 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
# find_package(Qt6 COMPONENTS Core Network REQUIRED)
|
# find_package(Qt6 COMPONENTS Core Network REQUIRED)
|
||||||
# find_package(qtmcp REQUIRED)
|
# find_package(qtmcp REQUIRED)
|
||||||
|
|
||||||
add_executable(kom_mcp src/main.cpp)
|
find_package(PkgConfig REQUIRED)
|
||||||
|
pkg_check_modules(PQXX REQUIRED libpqxx)
|
||||||
|
|
||||||
|
add_executable(kom_mcp
|
||||||
|
src/main.cpp
|
||||||
|
src/dal/PgDal.cpp
|
||||||
|
)
|
||||||
# target_link_libraries(kom_mcp PRIVATE Qt6::Core Qt6::Network qtmcp)
|
# target_link_libraries(kom_mcp PRIVATE Qt6::Core Qt6::Network qtmcp)
|
||||||
|
target_include_directories(kom_mcp PRIVATE src ${PQXX_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(kom_mcp PRIVATE ${PQXX_LIBRARIES})
|
||||||
|
|
||||||
install(TARGETS kom_mcp RUNTIME DESTINATION bin)
|
install(TARGETS kom_mcp RUNTIME DESTINATION bin)
|
||||||
|
|
||||||
add_executable(test_mcp_tools tests/contract/test_mcp_tools.cpp)
|
add_executable(test_mcp_tools
|
||||||
target_include_directories(test_mcp_tools PRIVATE src)
|
tests/contract/test_mcp_tools.cpp
|
||||||
|
src/dal/PgDal.cpp
|
||||||
|
)
|
||||||
|
target_include_directories(test_mcp_tools PRIVATE src ${PQXX_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(test_mcp_tools PRIVATE ${PQXX_LIBRARIES})
|
||||||
|
|
||||||
enable_testing()
|
enable_testing()
|
||||||
add_test(NAME contract_mcp_tools COMMAND test_mcp_tools)
|
add_test(NAME contract_mcp_tools COMMAND test_mcp_tools)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <optional>
|
|
||||||
#include "Models.hpp"
|
#include "Models.hpp"
|
||||||
|
|
||||||
class IDatabase {
|
class IDatabase {
|
||||||
public:
|
public:
|
||||||
virtual ~IDatabase() = default;
|
virtual ~IDatabase() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialise the connection using a libpq/pqxx compatible DSN.
|
||||||
|
* The stub implementation keeps data in-process but honours the API.
|
||||||
|
*/
|
||||||
virtual bool connect(const std::string& dsn) = 0;
|
virtual bool connect(const std::string& dsn) = 0;
|
||||||
virtual void close() = 0;
|
virtual void close() = 0;
|
||||||
|
|
||||||
|
|
@ -17,9 +22,11 @@ public:
|
||||||
|
|
||||||
// Memory ops (skeleton)
|
// Memory ops (skeleton)
|
||||||
virtual std::optional<NamespaceRow> ensureNamespace(const std::string& name) = 0;
|
virtual std::optional<NamespaceRow> ensureNamespace(const std::string& name) = 0;
|
||||||
|
virtual std::optional<NamespaceRow> findNamespace(const std::string& name) const = 0;
|
||||||
virtual std::string upsertItem(const ItemRow& item) = 0;
|
virtual std::string upsertItem(const ItemRow& item) = 0;
|
||||||
virtual std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) = 0;
|
virtual std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) = 0;
|
||||||
virtual std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) = 0;
|
virtual std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) = 0;
|
||||||
virtual std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) = 0;
|
virtual std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) = 0;
|
||||||
virtual std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) = 0;
|
virtual std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) = 0;
|
||||||
|
virtual std::optional<ItemRow> getItemById(const std::string& item_id) = 0;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,191 @@
|
||||||
#include "PgDal.hpp"
|
#include "PgDal.hpp"
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
// NOTE: Stub implementation (no libpq linked yet).
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <mutex>
|
||||||
|
#include <numeric>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
bool PgDal::connect(const std::string& dsn) { (void)dsn; std::cout << "[PgDal] connect stub
|
bool PgDal::connect(const std::string& dsn) {
|
||||||
"; return true; }
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
void PgDal::close() { std::cout << "[PgDal] close stub
|
dsn_ = dsn;
|
||||||
"; }
|
connected_ = true;
|
||||||
bool PgDal::begin() { std::cout << "[PgDal] begin stub
|
return true;
|
||||||
"; return true; }
|
}
|
||||||
bool PgDal::commit() { std::cout << "[PgDal] commit stub
|
|
||||||
"; return true; }
|
|
||||||
void PgDal::rollback() { std::cout << "[PgDal] rollback stub
|
|
||||||
"; }
|
|
||||||
|
|
||||||
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) { (void)name; return NamespaceRow{"00000000-0000-0000-0000-000000000000", name}; }
|
void PgDal::close() {
|
||||||
std::string PgDal::upsertItem(const ItemRow& item) { (void)item; return std::string("00000000-0000-0000-0000-000000000001"); }
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) { return std::vector<std::string>(chunks.size(), "00000000-0000-0000-0000-000000000002"); }
|
connected_ = false;
|
||||||
std::vector<std::string> PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embs) { return std::vector<std::string>(embs.size(), "00000000-0000-0000-0000-000000000003"); }
|
inTransaction_ = false;
|
||||||
std::vector<ItemRow> PgDal::searchText(const std::string& namespace_id, const std::string& query, int k) { (void)namespace_id; (void)query; (void)k; return {}; }
|
}
|
||||||
std::vector<std::pair<std::string,float>> PgDal::searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) { (void)namespace_id; (void)embedding; (void)k; return {}; }
|
|
||||||
|
bool PgDal::begin() {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
if (!connected_ || inTransaction_) return false;
|
||||||
|
inTransaction_ = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PgDal::commit() {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
if (!connected_ || !inTransaction_) return false;
|
||||||
|
inTransaction_ = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PgDal::rollback() {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
inTransaction_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
if (name.empty()) return std::nullopt;
|
||||||
|
auto it = namespacesByName_.find(name);
|
||||||
|
if (it != namespacesByName_.end()) return it->second;
|
||||||
|
|
||||||
|
NamespaceRow row;
|
||||||
|
row.id = makeSyntheticId("ns");
|
||||||
|
row.name = name;
|
||||||
|
namespacesByName_[name] = row;
|
||||||
|
namespacesById_[row.id] = row;
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
auto it = namespacesByName_.find(name);
|
||||||
|
if (it == namespacesByName_.end()) return std::nullopt;
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PgDal::upsertItem(const ItemRow& item) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
if (item.namespace_id.empty()) throw std::runtime_error("item missing namespace_id");
|
||||||
|
|
||||||
|
ItemRow stored = item;
|
||||||
|
if (stored.id.empty()) {
|
||||||
|
stored.id = makeSyntheticId("item");
|
||||||
|
}
|
||||||
|
itemsById_[stored.id] = stored;
|
||||||
|
|
||||||
|
auto& bucket = itemsByNamespace_[stored.namespace_id];
|
||||||
|
if (std::find(bucket.begin(), bucket.end(), stored.id) == bucket.end()) {
|
||||||
|
bucket.push_back(stored.id);
|
||||||
|
}
|
||||||
|
return stored.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
std::vector<std::string> ids;
|
||||||
|
ids.reserve(chunks.size());
|
||||||
|
for (auto chunk : chunks) {
|
||||||
|
if (chunk.id.empty()) {
|
||||||
|
chunk.id = makeSyntheticId("chunk");
|
||||||
|
}
|
||||||
|
chunksById_[chunk.id] = chunk;
|
||||||
|
ids.push_back(chunk.id);
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embs) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
std::vector<std::string> ids;
|
||||||
|
ids.reserve(embs.size());
|
||||||
|
for (auto emb : embs) {
|
||||||
|
if (emb.id.empty()) {
|
||||||
|
emb.id = makeSyntheticId("emb");
|
||||||
|
}
|
||||||
|
embeddingsById_[emb.id] = emb;
|
||||||
|
ids.push_back(emb.id);
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ItemRow> PgDal::searchText(const std::string& namespace_id, const std::string& query, int k) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
std::vector<ItemRow> result;
|
||||||
|
auto nsIt = itemsByNamespace_.find(namespace_id);
|
||||||
|
if (nsIt == itemsByNamespace_.end()) return result;
|
||||||
|
|
||||||
|
const std::string needle = toLowerCopy(query);
|
||||||
|
for (const auto& itemId : nsIt->second) {
|
||||||
|
auto itemIt = itemsById_.find(itemId);
|
||||||
|
if (itemIt == itemsById_.end()) continue;
|
||||||
|
|
||||||
|
if (needle.empty()) {
|
||||||
|
result.push_back(itemIt->second);
|
||||||
|
} else {
|
||||||
|
std::string hay = toLowerCopy(itemIt->second.text.value_or(""));
|
||||||
|
if (hay.find(needle) != std::string::npos) {
|
||||||
|
result.push_back(itemIt->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (k > 0 && static_cast<int>(result.size()) >= k) break;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<std::string, float>> PgDal::searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
std::unordered_map<std::string, float> bestScoreByItem;
|
||||||
|
if (embedding.empty()) return {};
|
||||||
|
|
||||||
|
for (const auto& [embeddingId, row] : embeddingsById_) {
|
||||||
|
auto chunkIt = chunksById_.find(row.chunk_id);
|
||||||
|
if (chunkIt == chunksById_.end()) continue;
|
||||||
|
auto itemIt = itemsById_.find(chunkIt->second.item_id);
|
||||||
|
if (itemIt == itemsById_.end()) continue;
|
||||||
|
if (itemIt->second.namespace_id != namespace_id) continue;
|
||||||
|
|
||||||
|
const auto& storedVec = row.vector;
|
||||||
|
if (storedVec.empty()) continue;
|
||||||
|
std::size_t dim = std::min(storedVec.size(), embedding.size());
|
||||||
|
if (dim == 0) continue;
|
||||||
|
|
||||||
|
auto span = static_cast<std::ptrdiff_t>(dim);
|
||||||
|
float dot = std::inner_product(storedVec.begin(), storedVec.begin() + span,
|
||||||
|
embedding.begin(), 0.0f);
|
||||||
|
float score = dot / static_cast<float>(dim);
|
||||||
|
|
||||||
|
auto [it, inserted] = bestScoreByItem.emplace(itemIt->first, score);
|
||||||
|
if (!inserted && score > it->second) {
|
||||||
|
it->second = score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<std::string, float>> scored;
|
||||||
|
scored.reserve(bestScoreByItem.size());
|
||||||
|
for (const auto& kv : bestScoreByItem) scored.push_back(kv);
|
||||||
|
|
||||||
|
std::sort(scored.begin(), scored.end(), [](const auto& lhs, const auto& rhs) {
|
||||||
|
return lhs.second > rhs.second;
|
||||||
|
});
|
||||||
|
if (k > 0 && static_cast<std::size_t>(k) < scored.size()) {
|
||||||
|
scored.resize(static_cast<std::size_t>(k));
|
||||||
|
}
|
||||||
|
return scored;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ItemRow> PgDal::getItemById(const std::string& item_id) {
|
||||||
|
std::lock_guard<std::mutex> lock(guard_);
|
||||||
|
auto it = itemsById_.find(item_id);
|
||||||
|
if (it == itemsById_.end()) return std::nullopt;
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PgDal::makeSyntheticId(const std::string& prefix) {
|
||||||
|
return prefix + "-" + std::to_string(idCounter_++);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PgDal::toLowerCopy(const std::string& in) {
|
||||||
|
std::string out = in;
|
||||||
|
std::transform(out.begin(), out.end(), out.begin(), [](unsigned char c) {
|
||||||
|
return static_cast<char>(std::tolower(c));
|
||||||
|
});
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "IDatabase.hpp"
|
#include "IDatabase.hpp"
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
class PgDal : public IDatabase {
|
class PgDal : public IDatabase {
|
||||||
public:
|
public:
|
||||||
|
|
@ -10,9 +14,29 @@ public:
|
||||||
void rollback() override;
|
void rollback() override;
|
||||||
|
|
||||||
std::optional<NamespaceRow> ensureNamespace(const std::string& name) override;
|
std::optional<NamespaceRow> ensureNamespace(const std::string& name) override;
|
||||||
|
std::optional<NamespaceRow> findNamespace(const std::string& name) const override;
|
||||||
std::string upsertItem(const ItemRow& item) override;
|
std::string upsertItem(const ItemRow& item) override;
|
||||||
std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) override;
|
std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) override;
|
||||||
std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) override;
|
std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) override;
|
||||||
std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) override;
|
std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) override;
|
||||||
std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) override;
|
std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) override;
|
||||||
|
std::optional<ItemRow> getItemById(const std::string& item_id) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string makeSyntheticId(const std::string& prefix);
|
||||||
|
static std::string toLowerCopy(const std::string& in);
|
||||||
|
|
||||||
|
bool connected_{false};
|
||||||
|
bool inTransaction_{false};
|
||||||
|
std::string dsn_;
|
||||||
|
std::size_t idCounter_{1};
|
||||||
|
|
||||||
|
std::unordered_map<std::string, NamespaceRow> namespacesByName_;
|
||||||
|
std::unordered_map<std::string, NamespaceRow> namespacesById_;
|
||||||
|
std::unordered_map<std::string, ItemRow> itemsById_;
|
||||||
|
std::unordered_map<std::string, std::vector<std::string>> itemsByNamespace_;
|
||||||
|
std::unordered_map<std::string, ChunkRow> chunksById_;
|
||||||
|
std::unordered_map<std::string, EmbeddingRow> embeddingsById_;
|
||||||
|
|
||||||
|
mutable std::mutex guard_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,11 @@ int main(int argc, char** argv) {
|
||||||
KomMcpServer server;
|
KomMcpServer server;
|
||||||
register_default_tools(server);
|
register_default_tools(server);
|
||||||
|
|
||||||
|
const char* pgDsn = std::getenv("PG_DSN");
|
||||||
|
if (!pgDsn || !*pgDsn) {
|
||||||
|
std::cerr << "[kom_mcp] PG_DSN not set; fallback DAL will be used if available.\n";
|
||||||
|
}
|
||||||
|
|
||||||
if (argc < 2) {
|
if (argc < 2) {
|
||||||
print_usage(argv[0], server);
|
print_usage(argv[0], server);
|
||||||
return 1;
|
return 1;
|
||||||
|
|
@ -75,3 +80,4 @@ int main(int argc, char** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#include <cstdlib>
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,33 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <optional>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dal/PgDal.hpp"
|
||||||
|
|
||||||
namespace Handlers {
|
namespace Handlers {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline PgDal& database() {
|
||||||
|
static PgDal instance;
|
||||||
|
static bool connected = [] {
|
||||||
|
const char* env = std::getenv("PG_DSN");
|
||||||
|
const std::string dsn = (env && *env) ? std::string(env) : std::string();
|
||||||
|
if (!dsn.empty()) {
|
||||||
|
return instance.connect(dsn);
|
||||||
|
}
|
||||||
|
return instance.connect("stub://memory");
|
||||||
|
}();
|
||||||
|
(void)connected;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
inline std::string json_escape(const std::string& in) {
|
inline std::string json_escape(const std::string& in) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
|
@ -20,31 +44,61 @@ inline std::string json_escape(const std::string& in) {
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline size_t count_items_array(const std::string& json) {
|
inline std::string error_response(const std::string& code, const std::string& message) {
|
||||||
auto pos = json.find("\"items\"");
|
std::ostringstream os;
|
||||||
if (pos == std::string::npos) return 0;
|
os << "{\"error\":{\"code\":\"" << json_escape(code)
|
||||||
pos = json.find('[', pos);
|
<< "\",\"message\":\"" << json_escape(message) << "\"}}";
|
||||||
if (pos == std::string::npos) return 0;
|
return os.str();
|
||||||
size_t count = 0;
|
}
|
||||||
|
|
||||||
|
inline std::optional<std::string> find_delimited_segment(const std::string& json,
|
||||||
|
const std::string& key,
|
||||||
|
char open,
|
||||||
|
char close) {
|
||||||
|
const std::string pattern = "\"" + key + "\"";
|
||||||
|
auto pos = json.find(pattern);
|
||||||
|
if (pos == std::string::npos) return std::nullopt;
|
||||||
|
pos = json.find(open, pos);
|
||||||
|
if (pos == std::string::npos) return std::nullopt;
|
||||||
|
|
||||||
int depth = 0;
|
int depth = 0;
|
||||||
bool inString = false;
|
bool inString = false;
|
||||||
bool escape = false;
|
bool escape = false;
|
||||||
for (size_t i = pos + 1; i < json.size(); ++i) {
|
std::size_t start = std::string::npos;
|
||||||
|
for (std::size_t i = pos; i < json.size(); ++i) {
|
||||||
char c = json[i];
|
char c = json[i];
|
||||||
if (escape) { escape = false; continue; }
|
if (escape) { escape = false; continue; }
|
||||||
if (c == '\\') { if (inString) escape = true; continue; }
|
if (c == '\\') {
|
||||||
if (c == '\"') { inString = !inString; continue; }
|
if (inString) escape = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (c == '\"') {
|
||||||
|
inString = !inString;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (inString) continue;
|
if (inString) continue;
|
||||||
if (c == '{') {
|
|
||||||
if (depth == 0) ++count;
|
if (c == open) {
|
||||||
|
if (depth == 0) start = i;
|
||||||
++depth;
|
++depth;
|
||||||
} else if (c == '}') {
|
continue;
|
||||||
if (depth > 0) --depth;
|
}
|
||||||
} else if (c == ']' && depth == 0) {
|
if (c == close) {
|
||||||
break;
|
--depth;
|
||||||
|
if (depth == 0 && start != std::string::npos) {
|
||||||
|
return json.substr(start, i - start + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return count;
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::optional<std::string> find_object_segment(const std::string& json, const std::string& key) {
|
||||||
|
return find_delimited_segment(json, key, '{', '}');
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::optional<std::string> find_array_segment(const std::string& json, const std::string& key) {
|
||||||
|
return find_delimited_segment(json, key, '[', ']');
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string extract_string_field(const std::string& json, const std::string& key) {
|
inline std::string extract_string_field(const std::string& json, const std::string& key) {
|
||||||
|
|
@ -83,22 +137,298 @@ inline std::string extract_string_field(const std::string& json, const std::stri
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string upsert_memory(const std::string& reqJson) {
|
inline std::optional<int> extract_int_field(const std::string& json, const std::string& key) {
|
||||||
size_t count = count_items_array(reqJson);
|
const std::string pattern = "\"" + key + "\"";
|
||||||
|
auto pos = json.find(pattern);
|
||||||
|
if (pos == std::string::npos) return std::nullopt;
|
||||||
|
pos = json.find(':', pos);
|
||||||
|
if (pos == std::string::npos) return std::nullopt;
|
||||||
|
++pos;
|
||||||
|
while (pos < json.size() && std::isspace(static_cast<unsigned char>(json[pos]))) ++pos;
|
||||||
|
std::size_t start = pos;
|
||||||
|
while (pos < json.size() && (std::isdigit(static_cast<unsigned char>(json[pos])) || json[pos] == '-')) ++pos;
|
||||||
|
if (start == pos) return std::nullopt;
|
||||||
|
try {
|
||||||
|
return std::stoi(json.substr(start, pos - start));
|
||||||
|
} catch (...) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<std::string> parse_object_array(const std::string& json, const std::string& key) {
|
||||||
|
std::vector<std::string> objects;
|
||||||
|
auto segment = find_array_segment(json, key);
|
||||||
|
if (!segment) return objects;
|
||||||
|
const std::string& arr = *segment;
|
||||||
|
|
||||||
|
int depth = 0;
|
||||||
|
bool inString = false;
|
||||||
|
bool escape = false;
|
||||||
|
std::size_t start = std::string::npos;
|
||||||
|
for (std::size_t i = 0; i < arr.size(); ++i) {
|
||||||
|
char c = arr[i];
|
||||||
|
if (escape) { escape = false; continue; }
|
||||||
|
if (c == '\\') {
|
||||||
|
if (inString) escape = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (c == '\"') {
|
||||||
|
inString = !inString;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (inString) continue;
|
||||||
|
|
||||||
|
if (c == '{') {
|
||||||
|
if (depth == 0) start = i;
|
||||||
|
++depth;
|
||||||
|
} else if (c == '}') {
|
||||||
|
--depth;
|
||||||
|
if (depth == 0 && start != std::string::npos) {
|
||||||
|
objects.push_back(arr.substr(start, i - start + 1));
|
||||||
|
start = std::string::npos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return objects;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<std::string> parse_string_array(const std::string& json, const std::string& key) {
|
||||||
|
std::vector<std::string> values;
|
||||||
|
auto segment = find_array_segment(json, key);
|
||||||
|
if (!segment) return values;
|
||||||
|
const std::string& arr = *segment;
|
||||||
|
|
||||||
|
bool inString = false;
|
||||||
|
bool escape = false;
|
||||||
|
std::ostringstream current;
|
||||||
|
for (std::size_t i = 0; i < arr.size(); ++i) {
|
||||||
|
char c = arr[i];
|
||||||
|
if (escape) {
|
||||||
|
switch (c) {
|
||||||
|
case '\"': current << '\"'; break;
|
||||||
|
case '\\': current << '\\'; break;
|
||||||
|
case 'n': current << '\n'; break;
|
||||||
|
case 'r': current << '\r'; break;
|
||||||
|
case 't': current << '\t'; break;
|
||||||
|
default: current << c; break;
|
||||||
|
}
|
||||||
|
escape = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (c == '\\' && inString) { escape = true; continue; }
|
||||||
|
if (c == '\"') {
|
||||||
|
if (inString) {
|
||||||
|
values.push_back(current.str());
|
||||||
|
current.str("");
|
||||||
|
current.clear();
|
||||||
|
}
|
||||||
|
inString = !inString;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (inString) current << c;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<float> parse_float_array(const std::string& json, const std::string& key) {
|
||||||
|
std::vector<float> values;
|
||||||
|
auto segment = find_array_segment(json, key);
|
||||||
|
if (!segment) return values;
|
||||||
|
const std::string& arr = *segment;
|
||||||
|
std::size_t pos = 0;
|
||||||
|
while (pos < arr.size()) {
|
||||||
|
while (pos < arr.size() && !std::isdigit(static_cast<unsigned char>(arr[pos])) && arr[pos] != '-' && arr[pos] != '+') ++pos;
|
||||||
|
if (pos >= arr.size()) break;
|
||||||
|
std::size_t end = pos;
|
||||||
|
while (end < arr.size() && (std::isdigit(static_cast<unsigned char>(arr[end])) || arr[end] == '.' || arr[end] == 'e' || arr[end] == 'E' || arr[end] == '+' || arr[end] == '-')) ++end;
|
||||||
|
try {
|
||||||
|
values.push_back(std::stof(arr.substr(pos, end - pos)));
|
||||||
|
} catch (...) {
|
||||||
|
// Skip unparsable token.
|
||||||
|
}
|
||||||
|
pos = end;
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string format_score(float score) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "{\"upserted\":" << count << ",\"status\":\"ok\"}";
|
os.setf(std::ios::fixed);
|
||||||
|
os << std::setprecision(3) << score;
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string search_memory(const std::string& reqJson) {
|
struct ParsedItem {
|
||||||
std::string queryText = extract_string_field(reqJson, "text");
|
std::string id;
|
||||||
|
std::string text;
|
||||||
|
std::vector<std::string> tags;
|
||||||
|
std::vector<float> embedding;
|
||||||
|
std::string rawJson;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::vector<ParsedItem> parse_items(const std::string& json) {
|
||||||
|
std::vector<ParsedItem> items;
|
||||||
|
for (const auto& obj : parse_object_array(json, "items")) {
|
||||||
|
ParsedItem item;
|
||||||
|
item.rawJson = obj;
|
||||||
|
item.id = extract_string_field(obj, "id");
|
||||||
|
item.text = extract_string_field(obj, "text");
|
||||||
|
item.tags = parse_string_array(obj, "tags");
|
||||||
|
item.embedding = parse_float_array(obj, "embedding");
|
||||||
|
items.push_back(std::move(item));
|
||||||
|
}
|
||||||
|
return items;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SearchMatch {
|
||||||
|
std::string id;
|
||||||
|
float score;
|
||||||
|
std::optional<std::string> text;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::string serialize_matches(const std::vector<SearchMatch>& matches) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "{\"matches\":[";
|
os << "{\"matches\":[";
|
||||||
if (!queryText.empty()) {
|
for (std::size_t i = 0; i < matches.size(); ++i) {
|
||||||
os << "{\"id\":\"stub-memory-1\",\"score\":0.42,\"text\":\"" << json_escape(queryText) << "\"}";
|
const auto& match = matches[i];
|
||||||
|
if (i) os << ",";
|
||||||
|
os << "{\"id\":\"" << json_escape(match.id) << "\""
|
||||||
|
<< ",\"score\":" << format_score(match.score);
|
||||||
|
if (match.text) {
|
||||||
|
os << ",\"text\":\"" << json_escape(*match.text) << "\"";
|
||||||
|
}
|
||||||
|
os << "}";
|
||||||
}
|
}
|
||||||
os << "]}";
|
os << "]}";
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
inline std::string upsert_memory(const std::string& reqJson) {
|
||||||
|
const std::string nsName = detail::extract_string_field(reqJson, "namespace");
|
||||||
|
if (nsName.empty()) {
|
||||||
|
return detail::error_response("bad_request", "namespace is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nsRow = detail::database().ensureNamespace(nsName);
|
||||||
|
if (!nsRow) {
|
||||||
|
return detail::error_response("internal_error", "failed to ensure namespace");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto items = detail::parse_items(reqJson);
|
||||||
|
if (items.empty()) {
|
||||||
|
return detail::error_response("bad_request", "items array must contain at least one entry");
|
||||||
|
}
|
||||||
|
|
||||||
|
PgDal& dal = detail::database();
|
||||||
|
const bool hasTx = dal.begin();
|
||||||
|
std::vector<std::string> ids;
|
||||||
|
ids.reserve(items.size());
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (auto& parsed : items) {
|
||||||
|
ItemRow row;
|
||||||
|
row.id = parsed.id;
|
||||||
|
row.namespace_id = nsRow->id;
|
||||||
|
row.content_json = parsed.rawJson;
|
||||||
|
row.text = parsed.text.empty() ? std::optional<std::string>() : std::optional<std::string>(parsed.text);
|
||||||
|
row.tags = parsed.tags;
|
||||||
|
row.revision = 1;
|
||||||
|
|
||||||
|
const std::string itemId = dal.upsertItem(row);
|
||||||
|
ids.push_back(itemId);
|
||||||
|
|
||||||
|
if (!parsed.embedding.empty()) {
|
||||||
|
ChunkRow chunk;
|
||||||
|
chunk.item_id = itemId;
|
||||||
|
chunk.ord = 0;
|
||||||
|
chunk.text = parsed.text;
|
||||||
|
auto chunkIds = dal.upsertChunks({chunk});
|
||||||
|
|
||||||
|
EmbeddingRow emb;
|
||||||
|
emb.chunk_id = chunkIds.front();
|
||||||
|
emb.model = "stub-model";
|
||||||
|
emb.dim = static_cast<int>(parsed.embedding.size());
|
||||||
|
emb.vector = parsed.embedding;
|
||||||
|
dal.upsertEmbeddings({emb});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hasTx) dal.commit();
|
||||||
|
} catch (const std::exception& ex) {
|
||||||
|
if (hasTx) dal.rollback();
|
||||||
|
return detail::error_response("internal_error", ex.what());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "{\"upserted\":" << ids.size();
|
||||||
|
if (!ids.empty()) {
|
||||||
|
os << ",\"ids\":[";
|
||||||
|
for (std::size_t i = 0; i < ids.size(); ++i) {
|
||||||
|
if (i) os << ",";
|
||||||
|
os << "\"" << detail::json_escape(ids[i]) << "\"";
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
}
|
||||||
|
os << ",\"status\":\"ok\"}";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string search_memory(const std::string& reqJson) {
|
||||||
|
const std::string nsName = detail::extract_string_field(reqJson, "namespace");
|
||||||
|
if (nsName.empty()) {
|
||||||
|
return detail::error_response("bad_request", "namespace is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nsRow = detail::database().findNamespace(nsName);
|
||||||
|
if (!nsRow) {
|
||||||
|
return "{\"matches\":[]}";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string queryText;
|
||||||
|
std::vector<float> queryEmbedding;
|
||||||
|
int limit = 5;
|
||||||
|
|
||||||
|
if (auto queryObj = detail::find_object_segment(reqJson, "query")) {
|
||||||
|
queryText = detail::extract_string_field(*queryObj, "text");
|
||||||
|
queryEmbedding = detail::parse_float_array(*queryObj, "embedding");
|
||||||
|
if (auto k = detail::extract_int_field(*queryObj, "k")) {
|
||||||
|
if (*k > 0) limit = *k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PgDal& dal = detail::database();
|
||||||
|
std::unordered_set<std::string> seen;
|
||||||
|
std::vector<detail::SearchMatch> matches;
|
||||||
|
|
||||||
|
auto textRows = dal.searchText(nsRow->id, queryText, limit);
|
||||||
|
for (std::size_t idx = 0; idx < textRows.size(); ++idx) {
|
||||||
|
const auto& row = textRows[idx];
|
||||||
|
detail::SearchMatch match;
|
||||||
|
match.id = row.id;
|
||||||
|
match.text = row.text;
|
||||||
|
match.score = 1.0f - static_cast<float>(idx) * 0.05f;
|
||||||
|
matches.push_back(match);
|
||||||
|
seen.insert(match.id);
|
||||||
|
if (static_cast<int>(matches.size()) >= limit) break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (static_cast<int>(matches.size()) < limit && !queryEmbedding.empty()) {
|
||||||
|
auto vectorMatches = dal.searchVector(nsRow->id, queryEmbedding, limit);
|
||||||
|
for (const auto& pair : vectorMatches) {
|
||||||
|
if (seen.count(pair.first)) continue;
|
||||||
|
auto item = dal.getItemById(pair.first);
|
||||||
|
if (!item) continue;
|
||||||
|
detail::SearchMatch match;
|
||||||
|
match.id = pair.first;
|
||||||
|
match.score = pair.second;
|
||||||
|
match.text = item->text;
|
||||||
|
matches.push_back(match);
|
||||||
|
if (static_cast<int>(matches.size()) >= limit) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return detail::serialize_matches(matches);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Handlers
|
} // namespace Handlers
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue