Add PgDal Implementation

This commit is contained in:
Χγφτ Kompanion 2025-10-14 21:46:46 +13:00
parent ba9c4c0f72
commit 93400a2d21
6 changed files with 593 additions and 45 deletions

View File

@ -8,13 +8,25 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
# find_package(Qt6 COMPONENTS Core Network REQUIRED) # find_package(Qt6 COMPONENTS Core Network REQUIRED)
# find_package(qtmcp REQUIRED) # find_package(qtmcp REQUIRED)
add_executable(kom_mcp src/main.cpp) find_package(PkgConfig REQUIRED)
pkg_check_modules(PQXX REQUIRED libpqxx)
add_executable(kom_mcp
src/main.cpp
src/dal/PgDal.cpp
)
# target_link_libraries(kom_mcp PRIVATE Qt6::Core Qt6::Network qtmcp) # target_link_libraries(kom_mcp PRIVATE Qt6::Core Qt6::Network qtmcp)
target_include_directories(kom_mcp PRIVATE src ${PQXX_INCLUDE_DIRS})
target_link_libraries(kom_mcp PRIVATE ${PQXX_LIBRARIES})
install(TARGETS kom_mcp RUNTIME DESTINATION bin) install(TARGETS kom_mcp RUNTIME DESTINATION bin)
add_executable(test_mcp_tools tests/contract/test_mcp_tools.cpp) add_executable(test_mcp_tools
target_include_directories(test_mcp_tools PRIVATE src) tests/contract/test_mcp_tools.cpp
src/dal/PgDal.cpp
)
target_include_directories(test_mcp_tools PRIVATE src ${PQXX_INCLUDE_DIRS})
target_link_libraries(test_mcp_tools PRIVATE ${PQXX_LIBRARIES})
enable_testing() enable_testing()
add_test(NAME contract_mcp_tools COMMAND test_mcp_tools) add_test(NAME contract_mcp_tools COMMAND test_mcp_tools)

View File

@ -1,12 +1,17 @@
#pragma once #pragma once
#include <optional>
#include <string> #include <string>
#include <vector> #include <vector>
#include <optional>
#include "Models.hpp" #include "Models.hpp"
class IDatabase { class IDatabase {
public: public:
virtual ~IDatabase() = default; virtual ~IDatabase() = default;
/**
* Initialise the connection using a libpq/pqxx compatible DSN.
* The stub implementation keeps data in-process but honours the API.
*/
virtual bool connect(const std::string& dsn) = 0; virtual bool connect(const std::string& dsn) = 0;
virtual void close() = 0; virtual void close() = 0;
@ -17,9 +22,11 @@ public:
// Memory ops (skeleton) // Memory ops (skeleton)
virtual std::optional<NamespaceRow> ensureNamespace(const std::string& name) = 0; virtual std::optional<NamespaceRow> ensureNamespace(const std::string& name) = 0;
virtual std::optional<NamespaceRow> findNamespace(const std::string& name) const = 0;
virtual std::string upsertItem(const ItemRow& item) = 0; virtual std::string upsertItem(const ItemRow& item) = 0;
virtual std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) = 0; virtual std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) = 0;
virtual std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) = 0; virtual std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) = 0;
virtual std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) = 0; virtual std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) = 0;
virtual std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) = 0; virtual std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) = 0;
virtual std::optional<ItemRow> getItemById(const std::string& item_id) = 0;
}; };

View File

@ -1,22 +1,191 @@
#include "PgDal.hpp" #include "PgDal.hpp"
#include <iostream>
// NOTE: Stub implementation (no libpq linked yet). #include <algorithm>
#include <cctype>
#include <cstddef>
#include <mutex>
#include <numeric>
#include <stdexcept>
bool PgDal::connect(const std::string& dsn) { (void)dsn; std::cout << "[PgDal] connect stub bool PgDal::connect(const std::string& dsn) {
"; return true; } std::lock_guard<std::mutex> lock(guard_);
void PgDal::close() { std::cout << "[PgDal] close stub dsn_ = dsn;
"; } connected_ = true;
bool PgDal::begin() { std::cout << "[PgDal] begin stub return true;
"; return true; } }
bool PgDal::commit() { std::cout << "[PgDal] commit stub
"; return true; }
void PgDal::rollback() { std::cout << "[PgDal] rollback stub
"; }
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) { (void)name; return NamespaceRow{"00000000-0000-0000-0000-000000000000", name}; } void PgDal::close() {
std::string PgDal::upsertItem(const ItemRow& item) { (void)item; return std::string("00000000-0000-0000-0000-000000000001"); } std::lock_guard<std::mutex> lock(guard_);
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) { return std::vector<std::string>(chunks.size(), "00000000-0000-0000-0000-000000000002"); } connected_ = false;
std::vector<std::string> PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embs) { return std::vector<std::string>(embs.size(), "00000000-0000-0000-0000-000000000003"); } inTransaction_ = false;
std::vector<ItemRow> PgDal::searchText(const std::string& namespace_id, const std::string& query, int k) { (void)namespace_id; (void)query; (void)k; return {}; } }
std::vector<std::pair<std::string,float>> PgDal::searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) { (void)namespace_id; (void)embedding; (void)k; return {}; }
/**
 * Marks a transaction as open.
 *
 * @return true when the flag was set; false when the DAL is not connected
 *         or a transaction is already marked open.
 */
bool PgDal::begin() {
    const std::scoped_lock lock(guard_);
    const bool canStart = connected_ && !inTransaction_;
    if (canStart) {
        inTransaction_ = true;
    }
    return canStart;
}
/**
 * Marks the open transaction as finished.
 *
 * @return true when a transaction was open on a connected DAL and the flag
 *         was cleared; false otherwise (state is left unchanged).
 */
bool PgDal::commit() {
    const std::scoped_lock lock(guard_);
    const bool canCommit = connected_ && inTransaction_;
    if (canCommit) {
        inTransaction_ = false;
    }
    return canCommit;
}
/**
 * Clears the transaction flag unconditionally.
 *
 * Only the flag is reset here; rows already written to the in-memory maps
 * are not touched by this function.
 */
void PgDal::rollback() {
    const std::scoped_lock lock(guard_);
    inTransaction_ = false;
}
std::optional<NamespaceRow> PgDal::ensureNamespace(const std::string& name) {
std::lock_guard<std::mutex> lock(guard_);
if (name.empty()) return std::nullopt;
auto it = namespacesByName_.find(name);
if (it != namespacesByName_.end()) return it->second;
NamespaceRow row;
row.id = makeSyntheticId("ns");
row.name = name;
namespacesByName_[name] = row;
namespacesById_[row.id] = row;
return row;
}
/**
 * Looks up a namespace by name without creating it.
 *
 * @return The stored row, or std::nullopt when no namespace has that name.
 */
std::optional<NamespaceRow> PgDal::findNamespace(const std::string& name) const {
    const std::scoped_lock lock(guard_);
    std::optional<NamespaceRow> found;
    if (const auto entry = namespacesByName_.find(name); entry != namespacesByName_.end()) {
        found = entry->second;
    }
    return found;
}
/**
 * Inserts or replaces an item and keeps the per-namespace index in sync.
 *
 * @param item Row to store. When `item.id` is empty a synthetic id is assigned.
 * @return The id under which the item is stored.
 * @throws std::runtime_error when `item.namespace_id` is empty.
 */
std::string PgDal::upsertItem(const ItemRow& item) {
    std::lock_guard<std::mutex> lock(guard_);
    if (item.namespace_id.empty()) throw std::runtime_error("item missing namespace_id");

    ItemRow stored = item;
    if (stored.id.empty()) {
        stored.id = makeSyntheticId("item");
    }

    // An existing item may be re-upserted under a different namespace. Remove
    // its id from the previous namespace's bucket first; otherwise searchText()
    // on the old namespace would keep returning the moved item.
    const auto existing = itemsById_.find(stored.id);
    if (existing != itemsById_.end() && existing->second.namespace_id != stored.namespace_id) {
        const auto previous = itemsByNamespace_.find(existing->second.namespace_id);
        if (previous != itemsByNamespace_.end()) {
            auto& oldIds = previous->second;
            oldIds.erase(std::remove(oldIds.begin(), oldIds.end(), stored.id), oldIds.end());
        }
    }

    itemsById_[stored.id] = stored;

    // Register the id in its namespace bucket exactly once.
    auto& bucket = itemsByNamespace_[stored.namespace_id];
    if (std::find(bucket.begin(), bucket.end(), stored.id) == bucket.end()) {
        bucket.push_back(stored.id);
    }
    return stored.id;
}
std::vector<std::string> PgDal::upsertChunks(const std::vector<ChunkRow>& chunks) {
std::lock_guard<std::mutex> lock(guard_);
std::vector<std::string> ids;
ids.reserve(chunks.size());
for (auto chunk : chunks) {
if (chunk.id.empty()) {
chunk.id = makeSyntheticId("chunk");
}
chunksById_[chunk.id] = chunk;
ids.push_back(chunk.id);
}
return ids;
}
std::vector<std::string> PgDal::upsertEmbeddings(const std::vector<EmbeddingRow>& embs) {
std::lock_guard<std::mutex> lock(guard_);
std::vector<std::string> ids;
ids.reserve(embs.size());
for (auto emb : embs) {
if (emb.id.empty()) {
emb.id = makeSyntheticId("emb");
}
embeddingsById_[emb.id] = emb;
ids.push_back(emb.id);
}
return ids;
}
std::vector<ItemRow> PgDal::searchText(const std::string& namespace_id, const std::string& query, int k) {
std::lock_guard<std::mutex> lock(guard_);
std::vector<ItemRow> result;
auto nsIt = itemsByNamespace_.find(namespace_id);
if (nsIt == itemsByNamespace_.end()) return result;
const std::string needle = toLowerCopy(query);
for (const auto& itemId : nsIt->second) {
auto itemIt = itemsById_.find(itemId);
if (itemIt == itemsById_.end()) continue;
if (needle.empty()) {
result.push_back(itemIt->second);
} else {
std::string hay = toLowerCopy(itemIt->second.text.value_or(""));
if (hay.find(needle) != std::string::npos) {
result.push_back(itemIt->second);
}
}
if (k > 0 && static_cast<int>(result.size()) >= k) break;
}
return result;
}
/**
 * Scores the items of one namespace against a query embedding.
 *
 * For each stored embedding whose chunk and item resolve and whose item is in
 * `namespace_id`, the score is the dot product over the leading
 * min(stored, query) dimensions, divided by that dimension count. The best
 * score per item is kept.
 *
 * @param namespace_id Namespace to restrict the search to.
 * @param embedding    Query vector; an empty vector yields no results.
 * @param k            Maximum number of results; values <= 0 mean "no limit".
 * @return (item id, score) pairs sorted by descending score, ties broken by
 *         ascending item id so the result is deterministic.
 */
std::vector<std::pair<std::string, float>> PgDal::searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) {
    std::lock_guard<std::mutex> lock(guard_);
    if (embedding.empty()) return {};

    std::unordered_map<std::string, float> bestScoreByItem;
    for (const auto& [embeddingId, row] : embeddingsById_) {
        (void)embeddingId;
        // Resolve embedding -> chunk -> item; skip dangling references.
        auto chunkIt = chunksById_.find(row.chunk_id);
        if (chunkIt == chunksById_.end()) continue;
        auto itemIt = itemsById_.find(chunkIt->second.item_id);
        if (itemIt == itemsById_.end()) continue;
        if (itemIt->second.namespace_id != namespace_id) continue;

        const auto& storedVec = row.vector;
        if (storedVec.empty()) continue;
        const std::size_t dim = std::min(storedVec.size(), embedding.size());
        if (dim == 0) continue;

        const auto span = static_cast<std::ptrdiff_t>(dim);
        const float dot = std::inner_product(storedVec.begin(), storedVec.begin() + span,
                                             embedding.begin(), 0.0f);
        const float score = dot / static_cast<float>(dim);
        // NaN (e.g. from inf - inf during accumulation) is unordered and would
        // violate the strict weak ordering std::sort requires; drop such rows.
        if (!(score == score)) continue;

        auto [it, inserted] = bestScoreByItem.emplace(itemIt->first, score);
        if (!inserted && score > it->second) {
            it->second = score;
        }
    }

    std::vector<std::pair<std::string, float>> scored;
    scored.reserve(bestScoreByItem.size());
    for (const auto& kv : bestScoreByItem) scored.push_back(kv);

    // The hash map's iteration order is unspecified, so break score ties on
    // the item id; otherwise both ordering and the top-k cut vary between runs.
    std::sort(scored.begin(), scored.end(), [](const auto& lhs, const auto& rhs) {
        if (lhs.second != rhs.second) return lhs.second > rhs.second;
        return lhs.first < rhs.first;
    });

    if (k > 0 && static_cast<std::size_t>(k) < scored.size()) {
        scored.resize(static_cast<std::size_t>(k));
    }
    return scored;
}
/**
 * Fetches a stored item by id.
 *
 * @return A copy of the stored row, or std::nullopt when the id is unknown.
 */
std::optional<ItemRow> PgDal::getItemById(const std::string& item_id) {
    const std::scoped_lock lock(guard_);
    std::optional<ItemRow> found;
    if (const auto entry = itemsById_.find(item_id); entry != itemsById_.end()) {
        found = entry->second;
    }
    return found;
}
/**
 * Builds an id of the form "<prefix>-<counter>" and advances the counter.
 *
 * Takes no lock itself; the callers in this file invoke it while already
 * holding guard_.
 */
std::string PgDal::makeSyntheticId(const std::string& prefix) {
    std::string id = prefix;
    id += "-";
    id += std::to_string(idCounter_);
    ++idCounter_;
    return id;
}
/**
 * Returns a copy of `in` with every byte passed through std::tolower.
 *
 * Each char is widened via unsigned char first, because std::tolower is
 * undefined for negative values other than EOF.
 */
std::string PgDal::toLowerCopy(const std::string& in) {
    std::string lowered;
    lowered.reserve(in.size());
    for (const char ch : in) {
        const auto byte = static_cast<unsigned char>(ch);
        lowered.push_back(static_cast<char>(std::tolower(byte)));
    }
    return lowered;
}

View File

@ -1,5 +1,9 @@
#pragma once #pragma once
#include "IDatabase.hpp" #include "IDatabase.hpp"
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
class PgDal : public IDatabase { class PgDal : public IDatabase {
public: public:
@ -10,9 +14,29 @@ public:
void rollback() override; void rollback() override;
std::optional<NamespaceRow> ensureNamespace(const std::string& name) override; std::optional<NamespaceRow> ensureNamespace(const std::string& name) override;
std::optional<NamespaceRow> findNamespace(const std::string& name) const override;
std::string upsertItem(const ItemRow& item) override; std::string upsertItem(const ItemRow& item) override;
std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) override; std::vector<std::string> upsertChunks(const std::vector<ChunkRow>& chunks) override;
std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) override; std::vector<std::string> upsertEmbeddings(const std::vector<EmbeddingRow>& embs) override;
std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) override; std::vector<ItemRow> searchText(const std::string& namespace_id, const std::string& query, int k) override;
std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) override; std::vector<std::pair<std::string,float>> searchVector(const std::string& namespace_id, const std::vector<float>& embedding, int k) override;
std::optional<ItemRow> getItemById(const std::string& item_id) override;
private:
std::string makeSyntheticId(const std::string& prefix);
static std::string toLowerCopy(const std::string& in);
bool connected_{false};
bool inTransaction_{false};
std::string dsn_;
std::size_t idCounter_{1};
std::unordered_map<std::string, NamespaceRow> namespacesByName_;
std::unordered_map<std::string, NamespaceRow> namespacesById_;
std::unordered_map<std::string, ItemRow> itemsById_;
std::unordered_map<std::string, std::vector<std::string>> itemsByNamespace_;
std::unordered_map<std::string, ChunkRow> chunksById_;
std::unordered_map<std::string, EmbeddingRow> embeddingsById_;
mutable std::mutex guard_;
}; };

View File

@ -46,6 +46,11 @@ int main(int argc, char** argv) {
KomMcpServer server; KomMcpServer server;
register_default_tools(server); register_default_tools(server);
const char* pgDsn = std::getenv("PG_DSN");
if (!pgDsn || !*pgDsn) {
std::cerr << "[kom_mcp] PG_DSN not set; fallback DAL will be used if available.\n";
}
if (argc < 2) { if (argc < 2) {
print_usage(argv[0], server); print_usage(argv[0], server);
return 1; return 1;
@ -75,3 +80,4 @@ int main(int argc, char** argv) {
return 1; return 1;
} }
} }
#include <cstdlib>

View File

@ -1,9 +1,33 @@
#pragma once #pragma once
#include <algorithm>
#include <cctype> #include <cctype>
#include <cstdlib>
#include <iomanip>
#include <optional>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set>
#include <vector>
#include "dal/PgDal.hpp"
namespace Handlers { namespace Handlers {
namespace detail {
/**
 * Returns the process-wide PgDal instance, connecting it on first use.
 *
 * The DSN is taken from the PG_DSN environment variable; when it is unset or
 * empty the instance is connected with "stub://memory" instead. The result of
 * connect() is computed once and otherwise ignored.
 */
inline PgDal& database() {
    static PgDal instance;
    static const bool connected = [] {
        const char* raw = std::getenv("PG_DSN");
        const bool haveDsn = raw != nullptr && *raw != '\0';
        const std::string target = haveDsn ? std::string(raw) : std::string("stub://memory");
        return instance.connect(target);
    }();
    static_cast<void>(connected);
    return instance;
}
inline std::string json_escape(const std::string& in) { inline std::string json_escape(const std::string& in) {
std::ostringstream os; std::ostringstream os;
@ -20,31 +44,61 @@ inline std::string json_escape(const std::string& in) {
return os.str(); return os.str();
} }
inline size_t count_items_array(const std::string& json) { inline std::string error_response(const std::string& code, const std::string& message) {
auto pos = json.find("\"items\""); std::ostringstream os;
if (pos == std::string::npos) return 0; os << "{\"error\":{\"code\":\"" << json_escape(code)
pos = json.find('[', pos); << "\",\"message\":\"" << json_escape(message) << "\"}}";
if (pos == std::string::npos) return 0; return os.str();
size_t count = 0; }
inline std::optional<std::string> find_delimited_segment(const std::string& json,
const std::string& key,
char open,
char close) {
const std::string pattern = "\"" + key + "\"";
auto pos = json.find(pattern);
if (pos == std::string::npos) return std::nullopt;
pos = json.find(open, pos);
if (pos == std::string::npos) return std::nullopt;
int depth = 0; int depth = 0;
bool inString = false; bool inString = false;
bool escape = false; bool escape = false;
for (size_t i = pos + 1; i < json.size(); ++i) { std::size_t start = std::string::npos;
for (std::size_t i = pos; i < json.size(); ++i) {
char c = json[i]; char c = json[i];
if (escape) { escape = false; continue; } if (escape) { escape = false; continue; }
if (c == '\\') { if (inString) escape = true; continue; } if (c == '\\') {
if (c == '\"') { inString = !inString; continue; } if (inString) escape = true;
continue;
}
if (c == '\"') {
inString = !inString;
continue;
}
if (inString) continue; if (inString) continue;
if (c == '{') {
if (depth == 0) ++count; if (c == open) {
if (depth == 0) start = i;
++depth; ++depth;
} else if (c == '}') { continue;
if (depth > 0) --depth; }
} else if (c == ']' && depth == 0) { if (c == close) {
break; --depth;
if (depth == 0 && start != std::string::npos) {
return json.substr(start, i - start + 1);
}
} }
} }
return count; return std::nullopt;
}
// Returns the first balanced `{...}` segment that follows the first occurrence
// of the quoted key in `json`, or std::nullopt when none is found.
// Thin wrapper over find_delimited_segment with '{' / '}' as delimiters.
inline std::optional<std::string> find_object_segment(const std::string& json, const std::string& key) {
return find_delimited_segment(json, key, '{', '}');
}
inline std::optional<std::string> find_array_segment(const std::string& json, const std::string& key) {
return find_delimited_segment(json, key, '[', ']');
} }
inline std::string extract_string_field(const std::string& json, const std::string& key) { inline std::string extract_string_field(const std::string& json, const std::string& key) {
@ -83,22 +137,298 @@ inline std::string extract_string_field(const std::string& json, const std::stri
return os.str(); return os.str();
} }
inline std::string upsert_memory(const std::string& reqJson) { inline std::optional<int> extract_int_field(const std::string& json, const std::string& key) {
size_t count = count_items_array(reqJson); const std::string pattern = "\"" + key + "\"";
auto pos = json.find(pattern);
if (pos == std::string::npos) return std::nullopt;
pos = json.find(':', pos);
if (pos == std::string::npos) return std::nullopt;
++pos;
while (pos < json.size() && std::isspace(static_cast<unsigned char>(json[pos]))) ++pos;
std::size_t start = pos;
while (pos < json.size() && (std::isdigit(static_cast<unsigned char>(json[pos])) || json[pos] == '-')) ++pos;
if (start == pos) return std::nullopt;
try {
return std::stoi(json.substr(start, pos - start));
} catch (...) {
return std::nullopt;
}
}
/**
 * Extracts the raw text of each top-level `{...}` object inside the JSON array
 * stored under `key`.
 *
 * Braces inside string literals are ignored. A closing brace with no matching
 * opener is skipped, so it cannot push the depth negative and hide the
 * objects that follow it.
 *
 * @return One string per top-level object, in document order; empty when the
 *         array is missing.
 */
inline std::vector<std::string> parse_object_array(const std::string& json, const std::string& key) {
    std::vector<std::string> objects;
    const auto segment = find_array_segment(json, key);
    if (!segment) return objects;
    const std::string& arr = *segment;

    int depth = 0;
    bool inString = false;
    bool escape = false;
    std::size_t start = std::string::npos;
    for (std::size_t i = 0; i < arr.size(); ++i) {
        const char c = arr[i];
        if (escape) { escape = false; continue; }
        if (c == '\\') {
            if (inString) escape = true;
            continue;
        }
        if (c == '\"') {
            inString = !inString;
            continue;
        }
        if (inString) continue;

        if (c == '{') {
            if (depth == 0) start = i;
            ++depth;
        } else if (c == '}' && depth > 0) {
            // Only close a brace that was actually opened.
            --depth;
            if (depth == 0 && start != std::string::npos) {
                objects.push_back(arr.substr(start, i - start + 1));
                start = std::string::npos;
            }
        }
    }
    return objects;
}
/**
 * Collects every string literal found inside the JSON array stored under `key`.
 *
 * Escape sequences are decoded: \" \\ \/ \n \r \t \b \f, and \uXXXX (emitted
 * as UTF-8, with surrogate pairs combined). A lone surrogate becomes U+FFFD.
 * A malformed \u sequence keeps the previous behaviour of emitting the
 * character after the backslash literally.
 *
 * @return Decoded strings in document order; empty when the array is missing.
 */
inline std::vector<std::string> parse_string_array(const std::string& json, const std::string& key) {
    std::vector<std::string> values;
    const auto segment = find_array_segment(json, key);
    if (!segment) return values;
    const std::string& arr = *segment;

    // Reads four hex digits starting at `at`; returns -1 when truncated or invalid.
    const auto readHex4 = [&arr](std::size_t at) -> long {
        if (at + 4 > arr.size()) return -1;
        long code = 0;
        for (std::size_t j = at; j < at + 4; ++j) {
            const unsigned char h = static_cast<unsigned char>(arr[j]);
            if (!std::isxdigit(h)) return -1;
            const int digit = std::isdigit(h) ? (h - '0') : (std::tolower(h) - 'a' + 10);
            code = code * 16 + digit;
        }
        return code;
    };

    // Appends the code point `cp` to `out` encoded as UTF-8.
    const auto appendUtf8 = [](std::string& out, unsigned long cp) {
        if (cp < 0x80UL) {
            out += static_cast<char>(cp);
        } else if (cp < 0x800UL) {
            out += static_cast<char>(0xC0UL | (cp >> 6));
            out += static_cast<char>(0x80UL | (cp & 0x3FUL));
        } else if (cp < 0x10000UL) {
            out += static_cast<char>(0xE0UL | (cp >> 12));
            out += static_cast<char>(0x80UL | ((cp >> 6) & 0x3FUL));
            out += static_cast<char>(0x80UL | (cp & 0x3FUL));
        } else {
            out += static_cast<char>(0xF0UL | (cp >> 18));
            out += static_cast<char>(0x80UL | ((cp >> 12) & 0x3FUL));
            out += static_cast<char>(0x80UL | ((cp >> 6) & 0x3FUL));
            out += static_cast<char>(0x80UL | (cp & 0x3FUL));
        }
    };

    bool inString = false;
    bool escape = false;
    std::string current;
    for (std::size_t i = 0; i < arr.size(); ++i) {
        const char c = arr[i];
        if (escape) {
            escape = false;
            switch (c) {
                case 'n': current += '\n'; break;
                case 'r': current += '\r'; break;
                case 't': current += '\t'; break;
                case 'b': current += '\b'; break;
                case 'f': current += '\f'; break;
                case 'u': {
                    long code = readHex4(i + 1);
                    if (code < 0) {
                        current += c;  // malformed escape: keep the literal 'u'
                        break;
                    }
                    i += 4;  // now on the last hex digit of this escape
                    // A high surrogate followed by "\uDC00".."\uDFFF" forms one code point.
                    if (code >= 0xD800 && code <= 0xDBFF &&
                        i + 2 < arr.size() && arr[i + 1] == '\\' && arr[i + 2] == 'u') {
                        const long low = readHex4(i + 3);
                        if (low >= 0xDC00 && low <= 0xDFFF) {
                            code = 0x10000 + ((code - 0xD800) << 10) + (low - 0xDC00);
                            i += 6;
                        }
                    }
                    if (code >= 0xD800 && code <= 0xDFFF) {
                        code = 0xFFFD;  // unpaired surrogate is not encodable
                    }
                    appendUtf8(current, static_cast<unsigned long>(code));
                    break;
                }
                default: current += c; break;  // covers \" \\ \/ and unknown escapes
            }
            continue;
        }
        if (c == '\\' && inString) { escape = true; continue; }
        if (c == '\"') {
            if (inString) {
                // Closing quote: the accumulated text is one complete value.
                values.push_back(current);
                current.clear();
            }
            inString = !inString;
            continue;
        }
        if (inString) current += c;
    }
    return values;
}
/**
 * Parses the numeric tokens found inside the JSON array stored under `key`.
 *
 * A token starts at a digit or sign and extends over digits, '.', 'e', 'E',
 * '+' and '-'. Each token is converted with std::stof; tokens it rejects are
 * skipped.
 *
 * @return Parsed values in document order; empty when the array is missing.
 */
inline std::vector<float> parse_float_array(const std::string& json, const std::string& key) {
    std::vector<float> parsed;
    const auto segment = find_array_segment(json, key);
    if (!segment) {
        return parsed;
    }
    const std::string& text = *segment;

    const auto isDigit = [](char ch) {
        return std::isdigit(static_cast<unsigned char>(ch)) != 0;
    };
    const auto startsNumber = [&isDigit](char ch) {
        return isDigit(ch) || ch == '-' || ch == '+';
    };
    const auto continuesNumber = [&isDigit](char ch) {
        return isDigit(ch) || ch == '.' || ch == 'e' || ch == 'E' || ch == '+' || ch == '-';
    };

    const std::size_t length = text.size();
    std::size_t cursor = 0;
    for (;;) {
        // Skip ahead to the next character that can begin a number.
        while (cursor < length && !startsNumber(text[cursor])) {
            ++cursor;
        }
        if (cursor >= length) {
            break;
        }

        std::size_t stop = cursor;
        while (stop < length && continuesNumber(text[stop])) {
            ++stop;
        }

        try {
            parsed.push_back(std::stof(text.substr(cursor, stop - cursor)));
        } catch (...) {
            // Skip unparsable token.
        }
        cursor = stop;
    }
    return parsed;
}
inline std::string format_score(float score) {
std::ostringstream os; std::ostringstream os;
os << "{\"upserted\":" << count << ",\"status\":\"ok\"}"; os.setf(std::ios::fixed);
os << std::setprecision(3) << score;
return os.str(); return os.str();
} }
inline std::string search_memory(const std::string& reqJson) { struct ParsedItem {
std::string queryText = extract_string_field(reqJson, "text"); std::string id;
std::string text;
std::vector<std::string> tags;
std::vector<float> embedding;
std::string rawJson;
};
inline std::vector<ParsedItem> parse_items(const std::string& json) {
std::vector<ParsedItem> items;
for (const auto& obj : parse_object_array(json, "items")) {
ParsedItem item;
item.rawJson = obj;
item.id = extract_string_field(obj, "id");
item.text = extract_string_field(obj, "text");
item.tags = parse_string_array(obj, "tags");
item.embedding = parse_float_array(obj, "embedding");
items.push_back(std::move(item));
}
return items;
}
/** One search hit as it is serialised into the "matches" response array. */
struct SearchMatch {
    std::string id;                   // id of the matched item
    float score{0.0f};                // relevance score; zero until assigned, never indeterminate
    std::optional<std::string> text;  // item text, omitted from the output when absent
};
inline std::string serialize_matches(const std::vector<SearchMatch>& matches) {
std::ostringstream os; std::ostringstream os;
os << "{\"matches\":["; os << "{\"matches\":[";
if (!queryText.empty()) { for (std::size_t i = 0; i < matches.size(); ++i) {
os << "{\"id\":\"stub-memory-1\",\"score\":0.42,\"text\":\"" << json_escape(queryText) << "\"}"; const auto& match = matches[i];
if (i) os << ",";
os << "{\"id\":\"" << json_escape(match.id) << "\""
<< ",\"score\":" << format_score(match.score);
if (match.text) {
os << ",\"text\":\"" << json_escape(*match.text) << "\"";
}
os << "}";
} }
os << "]}"; os << "]}";
return os.str(); return os.str();
} }
} // namespace detail
/**
 * Handles an upsert request: stores every entry of "items" under "namespace".
 *
 * The request is fully validated before the store is touched, so a rejected
 * request does not create the namespace as a side effect. For each item with
 * an embedding, one chunk and one embedding row are stored as well.
 *
 * @param reqJson Request body as JSON text.
 * @return `{"upserted":N,"ids":[...],"status":"ok"}` on success, or an
 *         `{"error":{...}}` object with code "bad_request" / "internal_error".
 */
inline std::string upsert_memory(const std::string& reqJson) {
    const std::string nsName = detail::extract_string_field(reqJson, "namespace");
    if (nsName.empty()) {
        return detail::error_response("bad_request", "namespace is required");
    }

    // Validate the payload first; only then create state in the store.
    const auto items = detail::parse_items(reqJson);
    if (items.empty()) {
        return detail::error_response("bad_request", "items array must contain at least one entry");
    }

    PgDal& dal = detail::database();
    const auto nsRow = dal.ensureNamespace(nsName);
    if (!nsRow) {
        return detail::error_response("internal_error", "failed to ensure namespace");
    }

    // begin() may refuse (not connected / already in a transaction); in that
    // case the writes proceed without commit or rollback.
    const bool hasTx = dal.begin();
    std::vector<std::string> ids;
    ids.reserve(items.size());
    try {
        for (const auto& parsed : items) {
            ItemRow row;
            row.id = parsed.id;
            row.namespace_id = nsRow->id;
            row.content_json = parsed.rawJson;
            row.text = parsed.text.empty() ? std::optional<std::string>() : std::optional<std::string>(parsed.text);
            row.tags = parsed.tags;
            row.revision = 1;

            const std::string itemId = dal.upsertItem(row);
            ids.push_back(itemId);

            if (!parsed.embedding.empty()) {
                // The whole item text is stored as a single chunk carrying the embedding.
                ChunkRow chunk;
                chunk.item_id = itemId;
                chunk.ord = 0;
                chunk.text = parsed.text;
                auto chunkIds = dal.upsertChunks({chunk});

                EmbeddingRow emb;
                emb.chunk_id = chunkIds.front();
                emb.model = "stub-model";
                emb.dim = static_cast<int>(parsed.embedding.size());
                emb.vector = parsed.embedding;
                dal.upsertEmbeddings({emb});
            }
        }
        // Do not report success when the transaction could not be committed.
        if (hasTx && !dal.commit()) {
            dal.rollback();
            return detail::error_response("internal_error", "failed to commit transaction");
        }
    } catch (const std::exception& ex) {
        if (hasTx) dal.rollback();
        return detail::error_response("internal_error", ex.what());
    }

    std::ostringstream os;
    os << "{\"upserted\":" << ids.size();
    if (!ids.empty()) {
        os << ",\"ids\":[";
        for (std::size_t i = 0; i < ids.size(); ++i) {
            if (i) os << ",";
            os << "\"" << detail::json_escape(ids[i]) << "\"";
        }
        os << "]";
    }
    os << ",\"status\":\"ok\"}";
    return os.str();
}
/**
 * Handles a search request against one namespace.
 *
 * Reads `query.text`, `query.embedding` and `query.k` (default limit 5;
 * non-positive k is ignored). Text matches come first with a synthetic score
 * that decreases by 0.05 per rank; vector matches then fill the remaining
 * slots, skipping items already returned.
 *
 * The text search is skipped when the query has an embedding but no text.
 * An empty text query matches every item, which would otherwise fill the
 * result with arbitrary items and leave the embedding unused.
 *
 * @param reqJson Request body as JSON text.
 * @return `{"matches":[...]}`, empty when the namespace is unknown, or an
 *         `{"error":{...}}` object when "namespace" is missing.
 */
inline std::string search_memory(const std::string& reqJson) {
    const std::string nsName = detail::extract_string_field(reqJson, "namespace");
    if (nsName.empty()) {
        return detail::error_response("bad_request", "namespace is required");
    }

    PgDal& dal = detail::database();
    const auto nsRow = dal.findNamespace(nsName);
    if (!nsRow) {
        return "{\"matches\":[]}";
    }

    std::string queryText;
    std::vector<float> queryEmbedding;
    int limit = 5;
    if (const auto queryObj = detail::find_object_segment(reqJson, "query")) {
        queryText = detail::extract_string_field(*queryObj, "text");
        queryEmbedding = detail::parse_float_array(*queryObj, "embedding");
        if (const auto k = detail::extract_int_field(*queryObj, "k")) {
            if (*k > 0) limit = *k;
        }
    }

    std::unordered_set<std::string> seen;
    std::vector<detail::SearchMatch> matches;

    // Keep the "list everything" behaviour only for queries with neither text
    // nor embedding; an embedding-only query must be answered by vector search.
    const bool runTextSearch = !queryText.empty() || queryEmbedding.empty();
    if (runTextSearch) {
        const auto textRows = dal.searchText(nsRow->id, queryText, limit);
        for (std::size_t idx = 0; idx < textRows.size(); ++idx) {
            const auto& row = textRows[idx];
            detail::SearchMatch match;
            match.id = row.id;
            match.text = row.text;
            match.score = 1.0f - static_cast<float>(idx) * 0.05f;
            matches.push_back(match);
            seen.insert(match.id);
            if (static_cast<int>(matches.size()) >= limit) break;
        }
    }

    if (static_cast<int>(matches.size()) < limit && !queryEmbedding.empty()) {
        const auto vectorMatches = dal.searchVector(nsRow->id, queryEmbedding, limit);
        for (const auto& pair : vectorMatches) {
            if (seen.count(pair.first)) continue;  // already returned as a text match
            const auto item = dal.getItemById(pair.first);
            if (!item) continue;
            detail::SearchMatch match;
            match.id = pair.first;
            match.score = pair.second;
            match.text = item->text;
            matches.push_back(match);
            if (static_cast<int>(matches.size()) >= limit) break;
        }
    }

    return detail::serialize_matches(matches);
}
} // namespace Handlers } // namespace Handlers