metal-kompanion/src/cli/KompanionApp.cpp

810 lines
24 KiB
C++

#include <QCommandLineOption>
#include <QCommandLineParser>
#include <QCoreApplication>
#include <QDir>
#include <QFile>
#include <QFileInfo>
#include <QTextStream>
#include <QUrl>
#include <QUrlQuery>
#include <QProcess>
#include <QRandomGenerator>
#include <QByteArray>
#include <QStandardPaths>
#include <QSqlDatabase>
#include <QSqlDriver>
#include <QSqlError>
#include <QSqlQuery>
#ifdef HAVE_KCONFIG
#include <KConfigGroup>
#include <KSharedConfig>
#else
#include <QSettings>
#endif
#include <algorithm>
#include <cstdlib>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <optional>
#include <sstream>
#include <string>
#include <vector>
#include "mcp/KomMcpServer.hpp"
#include "mcp/RegisterTools.hpp"
namespace {
// Root of the project source tree, used to locate the bundled db/init schema
// files when running from a build tree. When the build does not define
// PROJECT_SOURCE_DIR, the working directory at first call is used instead.
// The value is computed once and cached for the lifetime of the process.
const std::filesystem::path& projectRoot() {
#ifdef PROJECT_SOURCE_DIR
    static const std::filesystem::path cachedRoot(PROJECT_SOURCE_DIR);
#else
    static const std::filesystem::path cachedRoot = std::filesystem::current_path();
#endif
    return cachedRoot;
}
// Directory where schema files were installed, as configured at build time via
// KOMPANION_DB_INIT_INSTALL_DIR. Returns an empty path when the build did not
// define it; callers treat an empty path as "no install directory".
const std::filesystem::path& installedSchemaDir() {
    static const std::filesystem::path cachedDir =
#ifdef KOMPANION_DB_INIT_INSTALL_DIR
        std::filesystem::path(KOMPANION_DB_INIT_INSTALL_DIR);
#else
        std::filesystem::path();
#endif
    return cachedDir;
}
// Existing directories that may contain schema (*.sql) files, in priority
// order: the installed schema directory first (if configured), then
// <projectRoot>/db/init. Directories that do not exist are omitted.
std::vector<std::filesystem::path> schemaDirectories() {
    std::vector<std::filesystem::path> result;
    const auto addIfPresent = [&result](const std::filesystem::path& candidate) {
        if (!candidate.empty() && std::filesystem::exists(candidate)) {
            result.push_back(candidate);
        }
    };
    addIfPresent(installedSchemaDir());
    addIfPresent(projectRoot() / "db" / "init");
    return result;
}
// Gathers the *.sql migration scripts to apply, ordered by file name.
//
// Directories come from schemaDirectories() in priority order (installed copy
// first, then the source tree). When the same file name exists in more than one
// directory, only the first occurrence is kept. Otherwise each migration would
// be executed once per directory, and the second run fails on objects that
// already exist. Directories or entries that cannot be read are skipped rather
// than throwing std::filesystem::filesystem_error.
std::vector<std::filesystem::path> collectSchemaFiles() {
    std::vector<std::filesystem::path> files;
    for (const auto& dir : schemaDirectories()) {
        std::error_code ec;
        std::filesystem::directory_iterator it(dir, ec);
        const std::filesystem::directory_iterator end;
        for (; !ec && it != end; it.increment(ec)) {
            const auto& entry = *it;
            std::error_code typeEc;
            if (!entry.is_regular_file(typeEc) || entry.path().extension() != ".sql") {
                continue;
            }
            const auto name = entry.path().filename();
            const bool alreadyCollected = std::any_of(
                files.begin(), files.end(),
                [&name](const std::filesystem::path& known) { return known.filename() == name; });
            if (!alreadyCollected) {
                files.push_back(entry.path());
            }
        }
    }
    // Order by file name (not full path) so the migration sequence does not
    // depend on which directory a file came from.
    std::sort(files.begin(), files.end(),
              [](const std::filesystem::path& a, const std::filesystem::path& b) {
                  return a.filename() < b.filename();
              });
    return files;
}
// Drains `in` until end-of-file and returns everything read, byte for byte.
// Reads straight from the stream buffer, so no formatting or whitespace
// skipping is applied.
std::string readAll(std::istream& in) {
    std::ostringstream collected;
    if (std::streambuf* source = in.rdbuf()) {
        collected << source;
    }
    return collected.str();
}
#ifndef HAVE_KCONFIG
// Path of the INI file used to persist settings when KConfig is unavailable:
// <user config dir>/kompanionrc, falling back to the home directory when Qt
// reports no writable config location.
QString configFilePath() {
    const QString configRoot = QStandardPaths::writableLocation(QStandardPaths::ConfigLocation);
    const QString base = configRoot.isEmpty() ? QDir::homePath() : configRoot;
    return QDir(base).filePath(QStringLiteral("kompanionrc"));
}
#endif
std::optional<std::string> readDsnFromConfig() {
#ifdef HAVE_KCONFIG
auto config = KSharedConfig::openConfig(QStringLiteral("kompanionrc"));
if (!config) return std::nullopt;
KConfigGroup dbGroup(config, QStringLiteral("Database"));
const QString entry = dbGroup.readEntry(QStringLiteral("PgDsn"), QString());
if (entry.isEmpty()) return std::nullopt;
return entry.toStdString();
#else
QSettings settings(configFilePath(), QSettings::IniFormat);
const QString entry = settings.value(QStringLiteral("Database/PgDsn")).toString();
if (entry.isEmpty()) return std::nullopt;
return entry.toStdString();
#endif
}
void writeDsnToConfig(const std::string& dsn) {
#ifdef HAVE_KCONFIG
auto config = KSharedConfig::openConfig(QStringLiteral("kompanionrc"));
KConfigGroup dbGroup(config, QStringLiteral("Database"));
dbGroup.writeEntry(QStringLiteral("PgDsn"), QString::fromStdString(dsn));
config->sync();
#else
QSettings settings(configFilePath(), QSettings::IniFormat);
settings.beginGroup(QStringLiteral("Database"));
settings.setValue(QStringLiteral("PgDsn"), QString::fromStdString(dsn));
settings.endGroup();
settings.sync();
#endif
}
bool readFileUtf8(const QString& path, std::string& out, QString* error) {
QFile file(path);
if (!file.open(QIODevice::ReadOnly | QIODevice::Text)) {
if (error) {
*error = QStringLiteral("Unable to open request file: %1").arg(path);
}
return false;
}
const QByteArray data = file.readAll();
out = QString::fromUtf8(data).toStdString();
return true;
}
bool looksLikeFile(const QString& value) {
QFileInfo info(value);
return info.exists() && info.isFile();
}
// Prints "label [default]: ", reads one line from `in` and returns it trimmed.
// EOF or a blank line returns `def`.
//
// With `secret`, the default value is masked in the prompt so that a stored
// password is never echoed to the terminal or its scrollback. The user can
// still accept it by pressing Enter.
// NOTE(review): terminal echo of the typed input is NOT disabled here; doing so
// would need platform APIs (e.g. termios).
QString promptWithDefault(QTextStream& in,
                          QTextStream& out,
                          const QString& label,
                          const QString& def,
                          bool secret = false) {
    out << label;
    if (!def.isEmpty()) {
        out << " [" << (secret ? QStringLiteral("********") : def) << "]";
    }
    out << ": " << Qt::flush;
    const QString line = in.readLine();
    // A null line means EOF; both EOF and blank input accept the default.
    if (line.isNull() || line.trimmed().isEmpty()) {
        return def;
    }
    if (secret) {
        out << "\n";
    }
    return line.trimmed();
}
// Asks a yes/no question, showing the default in upper case ("[Y/n]"/"[y/N]").
// Accepts y/yes and n/no case-insensitively. EOF, blank input or any other
// answer returns `defaultYes`.
bool promptYesNo(QTextStream& in,
                 QTextStream& out,
                 const QString& question,
                 bool defaultYes) {
    out << question << (defaultYes ? " [Y/n]: " : " [y/N]: ") << Qt::flush;
    const QString answer = in.readLine().trimmed().toLower();
    if (answer == QLatin1String("y") || answer == QLatin1String("yes")) {
        return true;
    }
    if (answer == QLatin1String("n") || answer == QLatin1String("no")) {
        return false;
    }
    return defaultYes;
}
// Individual Postgres connection settings, as edited by the init wizard and
// derived from / rendered to a postgresql:// DSN.
struct ConnectionConfig {
    QString host{QStringLiteral("localhost")};  // TCP host (unused when useSocket)
    QString port{QStringLiteral("5432")};       // kept as text; parsed when used
    QString dbname{QStringLiteral("kompanion")};
    // Defaults to $USER, or "kompanion" when $USER is unset/empty.
    QString user = []() -> QString {
        const QByteArray login = qgetenv("USER");
        if (login.isEmpty()) {
            return QStringLiteral("kompanion");
        }
        return QString::fromLocal8Bit(login);
    }();
    QString password{QStringLiteral("komup")};
    bool useSocket{false};                                // connect via Unix socket dir
    QString socketPath{QStringLiteral("/var/run/postgresql")};
    QString options{};                                    // extra QSqlDatabase connect options
};
// Parses a postgresql:// URI DSN into a ConnectionConfig. A missing DSN, or any
// field absent from it, keeps the ConnectionConfig defaults.
//
// Query parameters follow libpq URI semantics:
// - a `host` value starting with '/' selects a Unix-socket directory;
// - any other non-empty `host` value overrides the URI host;
// - a positive `port` value overrides the URI port.
ConnectionConfig configFromDsn(const std::optional<std::string>& dsn) {
    ConnectionConfig cfg;
    if (!dsn) return cfg;
    const QUrl url(QString::fromStdString(*dsn));
    if (!url.host().isEmpty()) cfg.host = url.host();
    if (url.port() > 0) cfg.port = QString::number(url.port());
    if (!url.userName().isEmpty()) cfg.user = url.userName();
    if (!url.password().isEmpty()) cfg.password = url.password();
    if (!url.path().isEmpty()) cfg.dbname = url.path().mid(1);

    const QUrlQuery query(url);
    // Decode query values fully: a socket directory is a filesystem path and must
    // not keep any percent-escapes.
    const QString hostParam = query.queryItemValue(QStringLiteral("host"), QUrl::FullyDecoded);
    if (hostParam.startsWith(QLatin1Char('/'))) {
        cfg.useSocket = true;
        cfg.socketPath = hostParam;
    } else if (!hostParam.isEmpty()) {
        cfg.host = hostParam;
    }
    bool portOk = false;
    const int portParam =
        query.queryItemValue(QStringLiteral("port"), QUrl::FullyDecoded).toInt(&portOk);
    if (portOk && portParam > 0) {
        cfg.port = QString::number(portParam);
    }
    return cfg;
}
std::string buildDsn(const ConnectionConfig& cfg) {
QUrl url;
url.setScheme(QStringLiteral("postgresql"));
url.setUserName(cfg.user);
url.setPassword(cfg.password);
if (cfg.useSocket) {
QUrlQuery query;
query.addQueryItem(QStringLiteral("host"), cfg.socketPath);
url.setQuery(query);
} else {
url.setHost(cfg.host);
bool ok = false;
int port = cfg.port.toInt(&ok);
if (ok && port > 0) {
url.setPort(port);
}
}
url.setPath(QStringLiteral("/") + cfg.dbname);
return url.toString(QUrl::FullyEncoded).toStdString();
}
// Returns the first existing directory among the usual PostgreSQL Unix-socket
// locations (/var/run/postgresql, then /tmp), or an empty string if neither exists.
QString detectSocketPath() {
    static const char* const kSocketDirs[] = {"/var/run/postgresql", "/tmp"};
    for (const char* dir : kSocketDirs) {
        const QString path = QString::fromLatin1(dir);
        const QFileInfo info(path);
        if (info.exists() && info.isDir()) {
            return path;
        }
    }
    return QString();
}
// Lists non-template databases owned by the current Postgres role by running
// `psql` with its default connection settings. Best effort: returns an empty
// list if psql is missing, fails, or does not finish within 2 seconds.
QStringList listDatabasesOwnedByCurrentUser() {
    QProcess proc;
    // -X: ignore ~/.psqlrc so user settings (e.g. \timing) cannot add lines that
    //     would be parsed as database names.
    // -w: never prompt for a password; psql prompts on the TTY, which would
    //     otherwise show up in the middle of the wizard.
    const QStringList args{
        QStringLiteral("-X"), QStringLiteral("-w"),
        QStringLiteral("-At"), QStringLiteral("-c"),
        QStringLiteral("SELECT datname FROM pg_database WHERE datistemplate = false AND pg_get_userbyid(datdba) = current_user;")};
    // Give psql no interactive input.
    proc.setStandardInputFile(QProcess::nullDevice());
    proc.start(QStringLiteral("psql"), args);
    if (!proc.waitForFinished(2000) || proc.exitStatus() != QProcess::NormalExit || proc.exitCode() != 0) {
        return {};
    }
    const QString output = QString::fromUtf8(proc.readAllStandardOutput());
    QStringList lines = output.split(QLatin1Char('\n'), Qt::SkipEmptyParts);
    for (QString& line : lines) {
        line = line.trimmed();
    }
    lines.removeAll(QString());
    return lines;
}
bool testConnection(const std::string& dsn, QString* error = nullptr) {
if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
if (error) *error = QStringLiteral("QPSQL driver not available");
return false;
}
const QString connName = QStringLiteral("kompanion_check_%1")
.arg(QRandomGenerator::global()->generate64(), 0, 16);
QSqlDatabase db = QSqlDatabase::addDatabase(QStringLiteral("QPSQL"), connName);
const auto cfg = configFromDsn(std::optional<std::string>(dsn));
db.setDatabaseName(cfg.dbname);
if (!cfg.user.isEmpty()) db.setUserName(cfg.user);
if (!cfg.password.isEmpty()) db.setPassword(cfg.password);
if (cfg.useSocket) {
db.setHostName(cfg.socketPath);
} else {
db.setHostName(cfg.host);
}
bool portOk = false;
const int portValue = cfg.port.toInt(&portOk);
if (portOk && portValue > 0) {
db.setPort(portValue);
}
if (!cfg.options.isEmpty()) db.setConnectOptions(cfg.options);
const bool opened = db.open();
if (!opened && error) {
*error = db.lastError().text();
}
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
return opened;
}
// Checks whether the `public.memory_items` table exists on `db`.
// Returns false only if the query fails (with `*error` set when non-null).
// Otherwise returns true and writes the result to `*exists` (when non-null).
bool schemaExists(QSqlDatabase& db, bool* exists, QString* error) {
    QSqlQuery probe(db);
    const bool ran = probe.exec(QStringLiteral("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname='public' AND tablename='memory_items')"));
    if (!ran) {
        if (error != nullptr) {
            *error = probe.lastError().text();
        }
        return false;
    }
    // No result row is treated as "not present".
    const bool found = probe.next() && probe.value(0).toBool();
    if (exists != nullptr) {
        *exists = found;
    }
    return true;
}
// Executes every schema file from collectSchemaFiles() on `db`, in order. Each
// file is sent to the server as one query string.
//
// Behaviour per file:
// - unreadable files are reported and skipped;
// - whitespace-only files are skipped;
// - the first file that fails to execute stops the run and returns false.
// Also returns false when no schema files are found. With `verbose`, each
// applied file is reported on `out`.
bool applySchemaFiles(QSqlDatabase& db,
                      QTextStream& out,
                      bool verbose) {
    const auto files = collectSchemaFiles();
    if (files.empty()) {
        out << "No schema files found in search paths.\n";
        return false;
    }
    QSqlQuery query(db);
    for (const auto& path : files) {
        const QString displayName = QString::fromStdString(path.filename().string());
        std::ifstream sqlFile(path);
        if (!sqlFile) {
            out << "Skipping unreadable schema file: " << QString::fromStdString(path.string()) << "\n";
            continue;
        }
        // fromStdString decodes UTF-8 over the full byte length; fromUtf8(c_str())
        // would silently stop at an embedded NUL byte.
        const QString sql = QString::fromStdString(readAll(sqlFile));
        if (sql.trimmed().isEmpty()) {
            // NOTE(review): libpq reports an empty query string as PGRES_EMPTY_QUERY,
            // which the QPSQL driver likely treats as a failure — confirm.
            if (verbose) {
                out << "Skipping empty schema file: " << displayName << "\n";
            }
            continue;
        }
        if (!query.exec(sql)) {
            out << "Error applying schema " << displayName
                << ": " << query.lastError().text() << "\n";
            return false;
        }
        if (verbose) {
            out << "Applied schema: " << displayName << "\n";
        }
    }
    return true;
}
bool ensureSchema(const std::string& dsn,
QTextStream& out,
bool verbose) {
if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
out << "QPSQL driver not available.\n";
return false;
}
const QString connName = QStringLiteral("kompanion_schema_%1")
.arg(QRandomGenerator::global()->generate64(), 0, 16);
QSqlDatabase db = QSqlDatabase::addDatabase(QStringLiteral("QPSQL"), connName);
const auto cfg = configFromDsn(std::optional<std::string>(dsn));
db.setDatabaseName(cfg.dbname);
if (!cfg.user.isEmpty()) db.setUserName(cfg.user);
if (!cfg.password.isEmpty()) db.setPassword(cfg.password);
if (cfg.useSocket) {
db.setHostName(cfg.socketPath);
} else {
db.setHostName(cfg.host);
}
bool portOk = false;
const int portValue = cfg.port.toInt(&portOk);
if (portOk && portValue > 0) {
db.setPort(portValue);
}
if (!cfg.options.isEmpty()) db.setConnectOptions(cfg.options);
if (!db.open()) {
out << "Failed to connect for schema application: " << db.lastError().text() << "\n";
QSqlDatabase::removeDatabase(connName);
return false;
}
bool exists = false;
QString err;
if (!schemaExists(db, &exists, &err)) {
out << "Failed to check schema: " << err << "\n";
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
return false;
}
if (exists) {
if (verbose) out << "Schema already present.\n";
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
return true;
}
out << "Schema not found; applying migrations...\n";
if (!applySchemaFiles(db, out, verbose)) {
out << "Schema application reported errors.\n";
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
return false;
}
if (!schemaExists(db, &exists, &err) || !exists) {
out << "Schema still missing after applying migrations.\n";
if (!err.isEmpty()) {
out << "Last error: " << err << "\n";
}
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
return false;
}
db.close();
db = QSqlDatabase();
QSqlDatabase::removeDatabase(connName);
out << "Schema initialized successfully.\n";
return true;
}
std::optional<std::string> autoDetectDsn() {
if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
return std::nullopt;
}
QStringList candidates;
if (const char* env = std::getenv("PG_DSN"); env && *env) {
candidates << QString::fromUtf8(env);
}
const QString socketPath = detectSocketPath();
QStringList owned = listDatabasesOwnedByCurrentUser();
QStringList ordered;
if (owned.contains(QStringLiteral("kompanion"))) {
ordered << QStringLiteral("kompanion");
owned.removeAll(QStringLiteral("kompanion"));
}
if (owned.contains(QStringLiteral("kompanion_test"))) {
ordered << QStringLiteral("kompanion_test");
owned.removeAll(QStringLiteral("kompanion_test"));
}
ordered.append(owned);
for (const QString& dbName : ordered) {
if (!socketPath.isEmpty()) {
const QString encoded = QString::fromUtf8(QUrl::toPercentEncoding(socketPath));
candidates << QStringLiteral("postgresql:///%1?host=%2").arg(dbName, encoded);
}
candidates << QStringLiteral("postgresql://localhost/%1").arg(dbName);
}
candidates << QStringLiteral("postgresql://kompanion:komup@localhost/kompanion_test");
for (const QString& candidate : std::as_const(candidates)) {
if (candidate.trimmed().isEmpty()) continue;
if (testConnection(candidate.toStdString(), nullptr)) {
return candidate.toStdString();
}
}
return std::nullopt;
}
std::string jsonEscape(const QString& value) {
std::string out;
out.reserve(value.size());
for (QChar ch : value) {
const char c = static_cast<char>(ch.unicode());
switch (c) {
case '"': out += "\\\""; break;
case '\\': out += "\\\\"; break;
case '\b': out += "\\b"; break;
case '\f': out += "\\f"; break;
case '\n': out += "\\n"; break;
case '\r': out += "\\r"; break;
case '\t': out += "\\t"; break;
default:
if (static_cast<unsigned char>(c) < 0x20) {
char buffer[7];
std::snprintf(buffer, sizeof(buffer), "\\u%04x", static_cast<unsigned>(c));
out += buffer;
} else {
out += c;
}
break;
}
}
return out;
}
// Wraps plain text as a JSON object of the form {"prompt":"<escaped text>"}.
std::string makePromptPayload(const QString& prompt) {
    std::string payload = "{\"prompt\":\"";
    payload += jsonEscape(prompt);
    payload += "\"}";
    return payload;
}
// Interactive setup of the database connection.
//
// If autoDetectDsn() finds a working DSN, the user can accept it. Otherwise the
// wizard prompts for connection details (pre-filled from the detected DSN, or
// from the previous attempt) up to 5 times. On success the DSN is written to
// kompanionrc, exported as PG_DSN for this process, and ensureSchema() is run.
// Returns true once a DSN has been adopted (regardless of the schema result);
// false if the driver is missing, the user gives up, or all attempts fail.
bool runInitializationWizard(QTextStream& in,
                             QTextStream& out,
                             bool verbose) {
    out << "Kompanion initialization wizard\n"
        << "--------------------------------\n";
    if (!QSqlDatabase::isDriverAvailable(QStringLiteral("QPSQL"))) {
        out << "QPSQL driver not available. Please install the Qt PostgreSQL plugin (qt6-base or qt6-psql).\n";
        return false;
    }
    // Persist a verified DSN, export it to this process, and bring up the schema.
    const auto adopt = [&out, verbose](const std::string& dsn) {
        writeDsnToConfig(dsn);
        ::setenv("PG_DSN", dsn.c_str(), 1);
        ensureSchema(dsn, out, verbose);
    };
    const std::optional<std::string> detected = autoDetectDsn();
    ConnectionConfig cfg = configFromDsn(detected);
    if (detected) {
        out << "Detected working database at: " << QString::fromStdString(*detected) << "\n";
        if (promptYesNo(in, out, QStringLiteral("Use this configuration?"), true)) {
            adopt(*detected);
            return true;
        }
        // Declined: fall through to manual entry, pre-filled from the detected DSN.
    }
    int attempt = 0;
    while (attempt < 5) {
        ++attempt;
        ConnectionConfig entered;
        entered.host = promptWithDefault(in, out, QStringLiteral("Host"), cfg.host);
        entered.port = promptWithDefault(in, out, QStringLiteral("Port"), cfg.port);
        entered.dbname = promptWithDefault(in, out, QStringLiteral("Database name"), cfg.dbname);
        entered.user = promptWithDefault(in, out, QStringLiteral("User"), cfg.user);
        entered.password = promptWithDefault(in, out, QStringLiteral("Password"), cfg.password, true);
        entered.useSocket = promptYesNo(in, out, QStringLiteral("Use Unix socket connection?"), cfg.useSocket);
        entered.socketPath = entered.useSocket
            ? promptWithDefault(in, out, QStringLiteral("Socket path"), cfg.socketPath)
            : cfg.socketPath;
        const std::string dsn = buildDsn(entered);
        QString error;
        if (testConnection(dsn, &error)) {
            adopt(dsn);
            return true;
        }
        out << "Connection failed: " << error << "\n";
        if (!promptYesNo(in, out, QStringLiteral("Try again?"), true)) {
            return false;
        }
        // Offer the values just typed as defaults for the next attempt.
        cfg = entered;
    }
    out << "Too many failed attempts.\n";
    return false;
}
// Read-eval-print loop that sends each input line to `toolName` and prints the
// response.
//
// Input handling:
// - a line starting with "!prompt" is wrapped as {"prompt": "..."};
// - any other line is sent verbatim as the JSON payload;
// - EOF, an empty line, "quit" or "exit" ends the session.
// With `verbose`, the request and response are echoed with tags. Always returns 0.
int runInteractiveSession(KomMcpServer& server,
                          const std::string& toolName,
                          bool verbose) {
    QTextStream out(stdout);
    QTextStream in(stdin);
    const QString promptCommand = QStringLiteral("!prompt");
    out << "Interactive MCP session with tool `" << QString::fromStdString(toolName) << "`.\n"
        << "Enter JSON payloads, `!prompt <text>` to wrap plain text, or an empty line to exit.\n";
    while (true) {
        out << "json> " << Qt::flush;
        const QString line = in.readLine();
        if (line.isNull()) {
            break;  // EOF on stdin
        }
        const QString command = line.trimmed();
        if (command.isEmpty() || command == QStringLiteral("quit") || command == QStringLiteral("exit")) {
            break;
        }
        const std::string payload = command.startsWith(promptCommand)
            ? makePromptPayload(command.mid(promptCommand.length()).trimmed())
            : line.toStdString();
        if (verbose) {
            out << "[request] " << QString::fromStdString(payload) << "\n";
            out.flush();  // show the request before a potentially slow dispatch
        }
        const std::string response = server.dispatch(toolName, payload);
        if (verbose) {
            out << "[response] ";
        }
        out << QString::fromStdString(response) << "\n";
    }
    return 0;
}
bool resolveRequestPayload(const QCommandLineParser& parser,
const QStringList& positional,
const QCommandLineOption& requestOption,
const QCommandLineOption& stdinOption,
std::string& payloadOut,
QString* error) {
if (parser.isSet(stdinOption)) {
payloadOut = readAll(std::cin);
return true;
}
if (parser.isSet(requestOption)) {
const QString arg = parser.value(requestOption);
if (arg == "-" || parser.isSet(stdinOption)) {
payloadOut = readAll(std::cin);
return true;
}
if (looksLikeFile(arg)) {
return readFileUtf8(arg, payloadOut, error);
}
payloadOut = arg.toStdString();
return true;
}
if (positional.size() > 1) {
const QString arg = positional.at(1);
if (arg == "-") {
payloadOut = readAll(std::cin);
return true;
}
if (looksLikeFile(arg)) {
return readFileUtf8(arg, payloadOut, error);
}
payloadOut = arg.toStdString();
return true;
}
payloadOut = "{}";
return true;
}
// Writes the names of all tools registered on `server` to stdout, one per line.
void printToolList(const KomMcpServer& server) {
    QTextStream out(stdout);
    for (const auto& name : server.listTools()) {
        out << QString::fromStdString(name) << '\n';
    }
    out.flush();
}
} // namespace
// Entry point for the Kompanion MCP command-line client.
//
// DSN resolution, in order of precedence:
// 1. --dsn (exported as PG_DSN);
// 2. an existing PG_DSN environment variable;
// 3. the DSN stored in kompanionrc.
// The init wizard runs when --init is given or no DSN is available.
//
// After registering the default tools, main() either lists them (--list),
// starts an interactive session (--interactive), or dispatches a single request
// and prints the response to stdout. Returns 1 on initialization abort (with
// --init and no tool), unknown tool, or unreadable payload; otherwise 0.
int main(int argc, char** argv) {
    QCoreApplication app(argc, argv);
    QCoreApplication::setApplicationName("Kompanion");
    QCoreApplication::setApplicationVersion("0.1.0");
    QCommandLineParser parser;
    parser.setApplicationDescription("Kompanion MCP command-line client for personal memory tools.");
    parser.addHelpOption();
    parser.addVersionOption();
    QCommandLineOption listOption(QStringList() << "l" << "list",
                                  "List available tools and exit.");
    parser.addOption(listOption);
    QCommandLineOption initOption(QStringList() << "init",
                                  "Run the configuration wizard before executing commands.");
    parser.addOption(initOption);
    QCommandLineOption requestOption(QStringList() << "r" << "request",
                                     "JSON request payload or path to a JSON file.",
                                     "payload");
    parser.addOption(requestOption);
    QCommandLineOption stdinOption(QStringList() << "i" << "stdin",
                                   "Read request payload from standard input.");
    parser.addOption(stdinOption);
    QCommandLineOption interactiveOption(QStringList() << "I" << "interactive",
                                         "Enter interactive prompt mode for repeated requests.");
    parser.addOption(interactiveOption);
    QCommandLineOption verboseOption(QStringList() << "V" << "verbose",
                                     "Verbose mode; echo JSON request/response streams.");
    parser.addOption(verboseOption);
    QCommandLineOption dsnOption(QStringList() << "d" << "dsn",
                                 "Override the Postgres DSN used by the DAL (sets PG_DSN).",
                                 "dsn");
    parser.addOption(dsnOption);
    parser.addPositionalArgument("tool", "Tool name to invoke.");
    parser.addPositionalArgument("payload", "Optional JSON payload or file path (use '-' for stdin).", "[payload]");
    parser.process(app);

    QTextStream qin(stdin);
    QTextStream qout(stdout);
    QTextStream qerr(stderr);
    const bool verbose = parser.isSet(verboseOption);
    const bool interactive = parser.isSet(interactiveOption);
    const bool initRequested = parser.isSet(initOption);

    std::optional<std::string> configDsn = readDsnFromConfig();
    const char* envDsn = std::getenv("PG_DSN");
    if (parser.isSet(dsnOption)) {
        const QByteArray value = parser.value(dsnOption).toUtf8();
        ::setenv("PG_DSN", value.constData(), 1);
        envDsn = std::getenv("PG_DSN");  // re-read: setenv may invalidate earlier pointers
    }
    const bool needInit = (!envDsn || !*envDsn) && !configDsn;
    if (initRequested || needInit) {
        const bool initialized = runInitializationWizard(qin, qout, verbose);
        // QTextStream buffers output internally. Flush the wizard's messages now;
        // otherwise they are only written when qout/qerr are destroyed at the end
        // of main(), i.e. after the std::cout/std::cerr output below.
        qout.flush();
        if (!initialized) {
            qerr << "Initialization aborted.\n";
            qerr.flush();
            if (initRequested && parser.positionalArguments().isEmpty()) {
                return 1;
            }
        } else {
            configDsn = readDsnFromConfig();
            envDsn = std::getenv("PG_DSN");
        }
    }
    // Without an explicit --dsn or PG_DSN, fall back to the stored configuration.
    if (!parser.isSet(dsnOption)) {
        if (!envDsn || !*envDsn) {
            if (configDsn) {
                ::setenv("PG_DSN", configDsn->c_str(), 1);
                envDsn = std::getenv("PG_DSN");
            }
        }
    }

    KomMcpServer server;
    register_default_tools(server);
    if (initRequested && parser.positionalArguments().isEmpty()) {
        qout << "Configuration complete.\n";
        return 0;
    }
    if (parser.isSet(listOption)) {
        printToolList(server);
        return 0;
    }
    const QStringList positional = parser.positionalArguments();
    if (positional.isEmpty()) {
        parser.showHelp(1);  // does not return
    }
    const std::string toolName = positional.first().toStdString();
    if (!server.hasTool(toolName)) {
        std::cerr << "Unknown tool: " << toolName << "\n";
        printToolList(server);
        return 1;
    }
    if (interactive) {
        return runInteractiveSession(server, toolName, verbose);
    }
    std::string request;
    QString requestError;
    if (!resolveRequestPayload(parser,
                               positional,
                               requestOption,
                               stdinOption,
                               request,
                               &requestError)) {
        const QString message = requestError.isEmpty()
            ? QStringLiteral("Failed to resolve request payload.")
            : requestError;
        std::cerr << "Error: " << message.toStdString() << "\n";
        return 1;
    }
    if (verbose) {
        std::cerr << "[request] " << request << "\n";
    }
    const std::string response = server.dispatch(toolName, request);
    if (verbose) {
        std::cerr << "[response] " << response << "\n";
    }
    std::cout << response << std::endl;
    return 0;
}