ki: implement multi-text embeddings in OllamaProvider::embed by batching one request per text and aggregating results
This commit is contained in:
parent
e1eb3afacc
commit
2ee757b64a
|
|
@ -108,43 +108,43 @@ QFuture<KIReply*> OllamaProvider::chat(const KIThread& thread, const KIChatOptio
|
||||||
|
|
||||||
QFuture<KIEmbeddingResult> OllamaProvider::embed(const QStringList& texts, const KIEmbedOptions& opts)
|
QFuture<KIEmbeddingResult> OllamaProvider::embed(const QStringList& texts, const KIEmbedOptions& opts)
|
||||||
{
|
{
|
||||||
QNetworkRequest req{QUrl(QStringLiteral("http://localhost:11434/api/embeddings"))};
|
// Execute one request per input text; aggregate outputs.
|
||||||
|
QFutureInterface<KIEmbeddingResult> fi;
|
||||||
|
fi.reportStarted();
|
||||||
|
if (texts.isEmpty()) { KIEmbeddingResult r; r.model = opts.model; fi.reportResult(r); fi.reportFinished(); return fi.future(); }
|
||||||
|
|
||||||
|
struct Accum { QVector<QVector<float>> vectors; int remaining = 0; QString model; };
|
||||||
|
auto acc = new Accum();
|
||||||
|
acc->vectors.resize(texts.size());
|
||||||
|
acc->remaining = texts.size();
|
||||||
|
|
||||||
|
const QUrl url(QStringLiteral("http://localhost:11434/api/embeddings"));
|
||||||
|
for (int i = 0; i < texts.size(); ++i) {
|
||||||
|
QNetworkRequest req{url};
|
||||||
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
||||||
|
const QJsonObject body{ {QStringLiteral("model"), opts.model}, {QStringLiteral("prompt"), texts[i]} };
|
||||||
QJsonObject data;
|
auto rep = m_manager->post(req, QJsonDocument(body).toJson());
|
||||||
data["model"] = opts.model;
|
connect(rep, &QNetworkReply::finished, this, [rep, i, acc, fi]() mutable {
|
||||||
data["prompt"] = texts.join("\n"); // Join all texts into a single prompt
|
if (rep->error() == QNetworkReply::NoError) {
|
||||||
|
const auto obj = QJsonDocument::fromJson(rep->readAll()).object();
|
||||||
auto netReply = m_manager->post(req, QJsonDocument(data).toJson());
|
if (acc->model.isEmpty()) acc->model = obj.value(QStringLiteral("model")).toString();
|
||||||
|
const auto arr = obj.value(QStringLiteral("embedding")).toArray();
|
||||||
QFutureInterface<KIEmbeddingResult> interface;
|
QVector<float> vec; vec.reserve(arr.size());
|
||||||
interface.reportStarted();
|
for (const auto &v : arr) vec.push_back(static_cast<float>(v.toDouble()));
|
||||||
|
acc->vectors[i] = std::move(vec);
|
||||||
connect(netReply, &QNetworkReply::finished, this, [netReply, interface]() mutable {
|
|
||||||
if (netReply->error() != QNetworkReply::NoError) {
|
|
||||||
// TODO: Handle error
|
|
||||||
interface.reportFinished();
|
|
||||||
netReply->deleteLater();
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
rep->deleteLater();
|
||||||
const auto json = QJsonDocument::fromJson(netReply->readAll());
|
acc->remaining -= 1;
|
||||||
const auto embeddingArray = json["embedding"].toArray();
|
if (acc->remaining == 0) {
|
||||||
|
KIEmbeddingResult res; res.vectors = std::move(acc->vectors); res.model = acc->model;
|
||||||
KIEmbeddingResult result;
|
fi.reportResult(res);
|
||||||
QVector<float> embedding;
|
fi.reportFinished();
|
||||||
for (const QJsonValue &value : embeddingArray) {
|
delete acc;
|
||||||
embedding.push_back(value.toDouble());
|
|
||||||
}
|
}
|
||||||
result.vectors.push_back(embedding);
|
|
||||||
result.model = json["model"].toString();
|
|
||||||
|
|
||||||
interface.reportResult(result);
|
|
||||||
interface.reportFinished();
|
|
||||||
netReply->deleteLater();
|
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return interface.future();
|
return fi.future();
|
||||||
}
|
}
|
||||||
|
|
||||||
void OllamaProvider::cancel(quint64 requestId)
|
void OllamaProvider::cancel(quint64 requestId)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue