Ability to stream responses for faster updates.

This commit is contained in:
Syfaro 2020-01-24 00:29:13 -06:00
parent cad0016522
commit be5e1a9b97
4 changed files with 217 additions and 122 deletions

View File

@ -6,7 +6,8 @@ use warp::{Filter, Rejection, Reply};
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_file(db.clone()) search_file(db.clone())
.or(search_image(db.clone())) .or(search_image(db.clone()))
.or(search_hashes(db)) .or(search_hashes(db.clone()))
.or(stream_search_image(db))
} }
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@ -37,6 +38,17 @@ pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Reje
.and_then(handlers::search_hashes) .and_then(handlers::search_hashes)
} }
pub fn stream_search_image(
db: Pool,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("stream")
.and(warp::post())
.and(warp::multipart::form().max_length(1024 * 1024 * 10))
.and(with_pool(db))
.and(with_api_key())
.and_then(handlers::stream_image)
}
fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone { fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
warp::header::<String>("x-api-key") warp::header::<String>("x-api-key")
} }

View File

@ -1,6 +1,5 @@
use crate::models::image_query; use crate::models::{image_query, image_query_sync};
use crate::types::*; use crate::types::*;
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
use crate::{rate_limit, Pool}; use crate::{rate_limit, Pool};
use log::{debug, info}; use log::{debug, info};
use warp::{reject, Rejection, Reply}; use warp::{reject, Rejection, Reply};
@ -39,10 +38,10 @@ impl warp::reject::Reject for Error {}
pub async fn search_image( pub async fn search_image(
form: warp::multipart::FormData, form: warp::multipart::FormData,
opts: ImageSearchOpts, opts: ImageSearchOpts,
db: Pool, pool: Pool,
api_key: String, api_key: String,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?; let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image"); rate_limit!(&api_key, &db, image_limit, "image");
@ -79,28 +78,25 @@ pub async fn search_image(
debug!("Matching hash {}", num); debug!("Matching hash {}", num);
let results = { let mut items = {
if opts.search_type == Some(ImageSearchType::Force) { if opts.search_type == Some(ImageSearchType::Force) {
image_query(&db, vec![num], 10).await.unwrap() image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.await
.unwrap()
} else { } else {
let results = image_query(&db, vec![num], 0).await.unwrap(); let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
.await
.unwrap();
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) { if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
image_query(&db, vec![num], 10).await.unwrap() image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.await
.unwrap()
} else { } else {
results results
} }
} }
}; };
let mut items = Vec::with_capacity(results.len());
items.extend(extract_fa_rows(results.furaffinity, Some(&hash.as_bytes())));
items.extend(extract_e621_rows(results.e621, Some(&hash.as_bytes())));
items.extend(extract_twitter_rows(
results.twitter,
Some(&hash.as_bytes()),
));
items.sort_by(|a, b| { items.sort_by(|a, b| {
a.distance a.distance
.unwrap_or(u64::max_value()) .unwrap_or(u64::max_value())
@ -116,11 +112,75 @@ pub async fn search_image(
Ok(warp::reply::json(&similarity)) Ok(warp::reply::json(&similarity))
} }
pub async fn stream_image(
form: warp::multipart::FormData,
pool: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image", 2);
use bytes::BufMut;
use futures_util::StreamExt;
let parts: Vec<_> = form.collect().await;
let mut parts = parts
.into_iter()
.map(|part| {
let part = part.unwrap();
(part.name().to_string(), part)
})
.collect::<std::collections::HashMap<_, _>>();
let image = parts.remove("image").unwrap();
let bytes = image
.stream()
.fold(bytes::BytesMut::new(), |mut b, data| {
b.put(data.unwrap());
async move { b }
})
.await;
let hash = {
let hasher = crate::get_hasher();
let image = image::load_from_memory(&bytes).unwrap();
hasher.hash_image(&image)
};
let mut buf: [u8; 8] = [0; 8];
buf.copy_from_slice(&hash.as_bytes());
let num = i64::from_be_bytes(buf);
debug!("Stream matching hash {}", num);
let exact_event_stream =
image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
.map(sse_matches);
let close_event_stream =
image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.map(sse_matches);
let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
Ok(warp::sse::reply(event_stream))
}
fn sse_matches(
matches: Result<Vec<File>, tokio_postgres::Error>,
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
let items = matches.unwrap();
Ok(warp::sse::json(items))
}
pub async fn search_hashes( pub async fn search_hashes(
opts: HashSearchOpts, opts: HashSearchOpts,
db: Pool, db: Pool,
api_key: String, api_key: String,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let pool = db.clone();
let db = db.get().await.map_err(map_bb8_err)?; let db = db.get().await.map_err(map_bb8_err)?;
let hashes: Vec<i64> = opts let hashes: Vec<i64> = opts
@ -136,14 +196,12 @@ pub async fn search_hashes(
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let results = image_query(&db, hashes, 10) let mut results = image_query_sync(pool, hashes.clone(), 10, None);
.await let mut matches = Vec::new();
.map_err(|err| reject::custom(Error::from(err)))?;
let mut matches = Vec::with_capacity(results.len()); while let Some(r) = results.recv().await {
matches.extend(extract_fa_rows(results.furaffinity, None)); matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
matches.extend(extract_e621_rows(results.e621, None)); }
matches.extend(extract_twitter_rows(results.twitter, None));
Ok(warp::reply::json(&matches)) Ok(warp::reply::json(&matches))
} }

View File

@ -1,3 +1,5 @@
#![recursion_limit = "256"]
use std::str::FromStr; use std::str::FromStr;
mod filters; mod filters;

View File

@ -1,4 +1,6 @@
use crate::types::*; use crate::types::*;
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
use crate::Pool;
pub type DB<'a> = pub type DB<'a> =
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>; &'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
@ -35,107 +37,128 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
} }
} }
pub struct ImageQueryResults {
pub furaffinity: Vec<tokio_postgres::Row>,
pub e621: Vec<tokio_postgres::Row>,
pub twitter: Vec<tokio_postgres::Row>,
}
impl ImageQueryResults {
#[inline]
pub fn len(&self) -> usize {
self.furaffinity.len() + self.e621.len() + self.twitter.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
pub async fn image_query( pub async fn image_query(
db: DB<'_>, pool: Pool,
hashes: Vec<i64>, hashes: Vec<i64>,
distance: i64, distance: i64,
) -> Result<ImageQueryResults, tokio_postgres::Error> { hash: Option<Vec<u8>>,
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = ) -> Result<Vec<File>, tokio_postgres::Error> {
Vec::with_capacity(hashes.len() + 1); let mut results = image_query_sync(pool, hashes, distance, hash);
params.insert(0, &distance); let mut matches = Vec::new();
let mut fa_where_clause = Vec::with_capacity(hashes.len()); while let Some(r) = results.recv().await {
let mut hash_where_clause = Vec::with_capacity(hashes.len()); matches.extend(r?);
for (idx, hash) in hashes.iter().enumerate() {
params.push(hash);
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
} }
let hash_where_clause = hash_where_clause.join(" OR ");
let fa_query = format!( Ok(matches)
"SELECT }
submission.id,
submission.url, pub fn image_query_sync(
submission.filename, pool: Pool,
submission.file_id, hashes: Vec<i64>,
submission.hash, distance: i64,
submission.hash_int, hash: Option<Vec<u8>>,
artist.name ) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
FROM use futures_util::FutureExt;
submission
JOIN artist let (mut tx, rx) = tokio::sync::mpsc::channel(3);
ON artist.id = submission.artist_id
WHERE tokio::spawn(async move {
{}", let db = pool.get().await.unwrap();
fa_where_clause.join(" OR ")
); let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
Vec::with_capacity(hashes.len() + 1);
let e621_query = format!( params.insert(0, &distance);
"SELECT
e621.id, let mut fa_where_clause = Vec::with_capacity(hashes.len());
e621.hash, let mut hash_where_clause = Vec::with_capacity(hashes.len());
e621.data->>'file_url' url,
e621.data->>'md5' md5, for (idx, hash) in hashes.iter().enumerate() {
sources.list sources, params.push(hash);
artists.list artists,
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
FROM hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
e621, }
LATERAL ( let hash_where_clause = hash_where_clause.join(" OR ");
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'sources') s let fa_query = format!(
) sources, "SELECT
LATERAL ( submission.id,
SELECT array_agg(s) list submission.url,
FROM jsonb_array_elements_text(data->'artist') s submission.filename,
) artists submission.file_id,
WHERE submission.hash,
{}", submission.hash_int,
&hash_where_clause artist.name
); FROM
submission
let twitter_query = format!( JOIN artist
"SELECT ON artist.id = submission.artist_id
twitter_view.id, WHERE
twitter_view.artists, {}",
twitter_view.url, fa_where_clause.join(" OR ")
twitter_view.hash );
FROM
twitter_view let e621_query = format!(
WHERE "SELECT
{}", e621.id,
&hash_where_clause e621.hash,
); e621.data->>'file_url' url,
e621.data->>'md5' md5,
let furaffinity = db.query::<str>(&*fa_query, &params); sources.list sources,
let e621 = db.query::<str>(&*e621_query, &params); artists.list artists,
let twitter = db.query::<str>(&*twitter_query, &params); (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
FROM
let results = futures::future::join3(furaffinity, e621, twitter).await; e621,
Ok(ImageQueryResults { LATERAL (
furaffinity: results.0?, SELECT array_agg(s) list
e621: results.1?, FROM jsonb_array_elements_text(data->'sources') s
twitter: results.2?, ) sources,
}) LATERAL (
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'artist') s
) artists
WHERE
{}",
&hash_where_clause
);
let twitter_query = format!(
"SELECT
twitter_view.id,
twitter_view.artists,
twitter_view.url,
twitter_view.hash
FROM
twitter_view
WHERE
{}",
&hash_where_clause
);
let mut furaffinity = Box::pin(db.query::<str>(&*fa_query, &params).fuse());
let mut e621 = Box::pin(db.query::<str>(&*e621_query, &params).fuse());
let mut twitter = Box::pin(db.query::<str>(&*twitter_query, &params).fuse());
#[allow(clippy::unnecessary_mut_passed)]
loop {
futures::select! {
fa = furaffinity => {
let rows = fa.map(|rows| extract_fa_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
e = e621 => {
let rows = e.map(|rows| extract_e621_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
t = twitter => {
let rows = t.map(|rows| extract_twitter_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
complete => break,
}
}
});
rx
} }