mirror of
https://github.com/Syfaro/fuzzysearch.git
synced 2024-11-05 14:32:56 +00:00
Ability to stream responses for faster updates.
This commit is contained in:
parent
cad0016522
commit
be5e1a9b97
@ -6,7 +6,8 @@ use warp::{Filter, Rejection, Reply};
|
||||
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
search_file(db.clone())
|
||||
.or(search_image(db.clone()))
|
||||
.or(search_hashes(db))
|
||||
.or(search_hashes(db.clone()))
|
||||
.or(stream_search_image(db))
|
||||
}
|
||||
|
||||
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
@ -37,6 +38,17 @@ pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Reje
|
||||
.and_then(handlers::search_hashes)
|
||||
}
|
||||
|
||||
pub fn stream_search_image(
|
||||
db: Pool,
|
||||
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
warp::path("stream")
|
||||
.and(warp::post())
|
||||
.and(warp::multipart::form().max_length(1024 * 1024 * 10))
|
||||
.and(with_pool(db))
|
||||
.and(with_api_key())
|
||||
.and_then(handlers::stream_image)
|
||||
}
|
||||
|
||||
fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||
warp::header::<String>("x-api-key")
|
||||
}
|
||||
|
106
src/handlers.rs
106
src/handlers.rs
@ -1,6 +1,5 @@
|
||||
use crate::models::image_query;
|
||||
use crate::models::{image_query, image_query_sync};
|
||||
use crate::types::*;
|
||||
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
|
||||
use crate::{rate_limit, Pool};
|
||||
use log::{debug, info};
|
||||
use warp::{reject, Rejection, Reply};
|
||||
@ -39,10 +38,10 @@ impl warp::reject::Reject for Error {}
|
||||
pub async fn search_image(
|
||||
form: warp::multipart::FormData,
|
||||
opts: ImageSearchOpts,
|
||||
db: Pool,
|
||||
pool: Pool,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let db = db.get().await.map_err(map_bb8_err)?;
|
||||
let db = pool.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image");
|
||||
|
||||
@ -79,28 +78,25 @@ pub async fn search_image(
|
||||
|
||||
debug!("Matching hash {}", num);
|
||||
|
||||
let results = {
|
||||
let mut items = {
|
||||
if opts.search_type == Some(ImageSearchType::Force) {
|
||||
image_query(&db, vec![num], 10).await.unwrap()
|
||||
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
|
||||
.await
|
||||
.unwrap()
|
||||
} else {
|
||||
let results = image_query(&db, vec![num], 0).await.unwrap();
|
||||
let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
|
||||
.await
|
||||
.unwrap();
|
||||
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
|
||||
image_query(&db, vec![num], 10).await.unwrap()
|
||||
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
|
||||
.await
|
||||
.unwrap()
|
||||
} else {
|
||||
results
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut items = Vec::with_capacity(results.len());
|
||||
|
||||
items.extend(extract_fa_rows(results.furaffinity, Some(&hash.as_bytes())));
|
||||
items.extend(extract_e621_rows(results.e621, Some(&hash.as_bytes())));
|
||||
items.extend(extract_twitter_rows(
|
||||
results.twitter,
|
||||
Some(&hash.as_bytes()),
|
||||
));
|
||||
|
||||
items.sort_by(|a, b| {
|
||||
a.distance
|
||||
.unwrap_or(u64::max_value())
|
||||
@ -116,11 +112,75 @@ pub async fn search_image(
|
||||
Ok(warp::reply::json(&similarity))
|
||||
}
|
||||
|
||||
pub async fn stream_image(
|
||||
form: warp::multipart::FormData,
|
||||
pool: Pool,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let db = pool.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image", 2);
|
||||
|
||||
use bytes::BufMut;
|
||||
use futures_util::StreamExt;
|
||||
let parts: Vec<_> = form.collect().await;
|
||||
let mut parts = parts
|
||||
.into_iter()
|
||||
.map(|part| {
|
||||
let part = part.unwrap();
|
||||
(part.name().to_string(), part)
|
||||
})
|
||||
.collect::<std::collections::HashMap<_, _>>();
|
||||
let image = parts.remove("image").unwrap();
|
||||
|
||||
let bytes = image
|
||||
.stream()
|
||||
.fold(bytes::BytesMut::new(), |mut b, data| {
|
||||
b.put(data.unwrap());
|
||||
async move { b }
|
||||
})
|
||||
.await;
|
||||
|
||||
let hash = {
|
||||
let hasher = crate::get_hasher();
|
||||
let image = image::load_from_memory(&bytes).unwrap();
|
||||
hasher.hash_image(&image)
|
||||
};
|
||||
|
||||
let mut buf: [u8; 8] = [0; 8];
|
||||
buf.copy_from_slice(&hash.as_bytes());
|
||||
|
||||
let num = i64::from_be_bytes(buf);
|
||||
|
||||
debug!("Stream matching hash {}", num);
|
||||
|
||||
let exact_event_stream =
|
||||
image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
|
||||
.map(sse_matches);
|
||||
|
||||
let close_event_stream =
|
||||
image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
|
||||
.map(sse_matches);
|
||||
|
||||
let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
|
||||
|
||||
Ok(warp::sse::reply(event_stream))
|
||||
}
|
||||
|
||||
fn sse_matches(
|
||||
matches: Result<Vec<File>, tokio_postgres::Error>,
|
||||
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
|
||||
let items = matches.unwrap();
|
||||
|
||||
Ok(warp::sse::json(items))
|
||||
}
|
||||
|
||||
pub async fn search_hashes(
|
||||
opts: HashSearchOpts,
|
||||
db: Pool,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let pool = db.clone();
|
||||
let db = db.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
let hashes: Vec<i64> = opts
|
||||
@ -136,14 +196,12 @@ pub async fn search_hashes(
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
|
||||
|
||||
let results = image_query(&db, hashes, 10)
|
||||
.await
|
||||
.map_err(|err| reject::custom(Error::from(err)))?;
|
||||
let mut results = image_query_sync(pool, hashes.clone(), 10, None);
|
||||
let mut matches = Vec::new();
|
||||
|
||||
let mut matches = Vec::with_capacity(results.len());
|
||||
matches.extend(extract_fa_rows(results.furaffinity, None));
|
||||
matches.extend(extract_e621_rows(results.e621, None));
|
||||
matches.extend(extract_twitter_rows(results.twitter, None));
|
||||
while let Some(r) = results.recv().await {
|
||||
matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
|
||||
}
|
||||
|
||||
Ok(warp::reply::json(&matches))
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
mod filters;
|
||||
|
217
src/models.rs
217
src/models.rs
@ -1,4 +1,6 @@
|
||||
use crate::types::*;
|
||||
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
|
||||
use crate::Pool;
|
||||
|
||||
pub type DB<'a> =
|
||||
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||
@ -35,107 +37,128 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ImageQueryResults {
|
||||
pub furaffinity: Vec<tokio_postgres::Row>,
|
||||
pub e621: Vec<tokio_postgres::Row>,
|
||||
pub twitter: Vec<tokio_postgres::Row>,
|
||||
}
|
||||
|
||||
impl ImageQueryResults {
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.furaffinity.len() + self.e621.len() + self.twitter.len()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn image_query(
|
||||
db: DB<'_>,
|
||||
pool: Pool,
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
) -> Result<ImageQueryResults, tokio_postgres::Error> {
|
||||
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
|
||||
Vec::with_capacity(hashes.len() + 1);
|
||||
params.insert(0, &distance);
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> Result<Vec<File>, tokio_postgres::Error> {
|
||||
let mut results = image_query_sync(pool, hashes, distance, hash);
|
||||
let mut matches = Vec::new();
|
||||
|
||||
let mut fa_where_clause = Vec::with_capacity(hashes.len());
|
||||
let mut hash_where_clause = Vec::with_capacity(hashes.len());
|
||||
|
||||
for (idx, hash) in hashes.iter().enumerate() {
|
||||
params.push(hash);
|
||||
|
||||
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
|
||||
hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
|
||||
while let Some(r) = results.recv().await {
|
||||
matches.extend(r?);
|
||||
}
|
||||
let hash_where_clause = hash_where_clause.join(" OR ");
|
||||
|
||||
let fa_query = format!(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
submission.hash,
|
||||
submission.hash_int,
|
||||
artist.name
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
WHERE
|
||||
{}",
|
||||
fa_where_clause.join(" OR ")
|
||||
);
|
||||
|
||||
let e621_query = format!(
|
||||
"SELECT
|
||||
e621.id,
|
||||
e621.hash,
|
||||
e621.data->>'file_url' url,
|
||||
e621.data->>'md5' md5,
|
||||
sources.list sources,
|
||||
artists.list artists,
|
||||
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
|
||||
FROM
|
||||
e621,
|
||||
LATERAL (
|
||||
SELECT array_agg(s) list
|
||||
FROM jsonb_array_elements_text(data->'sources') s
|
||||
) sources,
|
||||
LATERAL (
|
||||
SELECT array_agg(s) list
|
||||
FROM jsonb_array_elements_text(data->'artist') s
|
||||
) artists
|
||||
WHERE
|
||||
{}",
|
||||
&hash_where_clause
|
||||
);
|
||||
|
||||
let twitter_query = format!(
|
||||
"SELECT
|
||||
twitter_view.id,
|
||||
twitter_view.artists,
|
||||
twitter_view.url,
|
||||
twitter_view.hash
|
||||
FROM
|
||||
twitter_view
|
||||
WHERE
|
||||
{}",
|
||||
&hash_where_clause
|
||||
);
|
||||
|
||||
let furaffinity = db.query::<str>(&*fa_query, ¶ms);
|
||||
let e621 = db.query::<str>(&*e621_query, ¶ms);
|
||||
let twitter = db.query::<str>(&*twitter_query, ¶ms);
|
||||
|
||||
let results = futures::future::join3(furaffinity, e621, twitter).await;
|
||||
Ok(ImageQueryResults {
|
||||
furaffinity: results.0?,
|
||||
e621: results.1?,
|
||||
twitter: results.2?,
|
||||
})
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
pub fn image_query_sync(
|
||||
pool: Pool,
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
|
||||
use futures_util::FutureExt;
|
||||
|
||||
let (mut tx, rx) = tokio::sync::mpsc::channel(3);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let db = pool.get().await.unwrap();
|
||||
|
||||
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
|
||||
Vec::with_capacity(hashes.len() + 1);
|
||||
params.insert(0, &distance);
|
||||
|
||||
let mut fa_where_clause = Vec::with_capacity(hashes.len());
|
||||
let mut hash_where_clause = Vec::with_capacity(hashes.len());
|
||||
|
||||
for (idx, hash) in hashes.iter().enumerate() {
|
||||
params.push(hash);
|
||||
|
||||
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
|
||||
hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
|
||||
}
|
||||
let hash_where_clause = hash_where_clause.join(" OR ");
|
||||
|
||||
let fa_query = format!(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
submission.hash,
|
||||
submission.hash_int,
|
||||
artist.name
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
WHERE
|
||||
{}",
|
||||
fa_where_clause.join(" OR ")
|
||||
);
|
||||
|
||||
let e621_query = format!(
|
||||
"SELECT
|
||||
e621.id,
|
||||
e621.hash,
|
||||
e621.data->>'file_url' url,
|
||||
e621.data->>'md5' md5,
|
||||
sources.list sources,
|
||||
artists.list artists,
|
||||
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
|
||||
FROM
|
||||
e621,
|
||||
LATERAL (
|
||||
SELECT array_agg(s) list
|
||||
FROM jsonb_array_elements_text(data->'sources') s
|
||||
) sources,
|
||||
LATERAL (
|
||||
SELECT array_agg(s) list
|
||||
FROM jsonb_array_elements_text(data->'artist') s
|
||||
) artists
|
||||
WHERE
|
||||
{}",
|
||||
&hash_where_clause
|
||||
);
|
||||
|
||||
let twitter_query = format!(
|
||||
"SELECT
|
||||
twitter_view.id,
|
||||
twitter_view.artists,
|
||||
twitter_view.url,
|
||||
twitter_view.hash
|
||||
FROM
|
||||
twitter_view
|
||||
WHERE
|
||||
{}",
|
||||
&hash_where_clause
|
||||
);
|
||||
|
||||
let mut furaffinity = Box::pin(db.query::<str>(&*fa_query, ¶ms).fuse());
|
||||
let mut e621 = Box::pin(db.query::<str>(&*e621_query, ¶ms).fuse());
|
||||
let mut twitter = Box::pin(db.query::<str>(&*twitter_query, ¶ms).fuse());
|
||||
|
||||
#[allow(clippy::unnecessary_mut_passed)]
|
||||
loop {
|
||||
futures::select! {
|
||||
fa = furaffinity => {
|
||||
let rows = fa.map(|rows| extract_fa_rows(rows, hash.as_deref()).into_iter().collect());
|
||||
tx.send(rows).await.unwrap();
|
||||
}
|
||||
e = e621 => {
|
||||
let rows = e.map(|rows| extract_e621_rows(rows, hash.as_deref()).into_iter().collect());
|
||||
tx.send(rows).await.unwrap();
|
||||
}
|
||||
t = twitter => {
|
||||
let rows = t.map(|rows| extract_twitter_rows(rows, hash.as_deref()).into_iter().collect());
|
||||
tx.send(rows).await.unwrap();
|
||||
}
|
||||
complete => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user