From be5e1a9b97d68df7bf71dd57f0a3bac5a0c202d4 Mon Sep 17 00:00:00 2001
From: Syfaro
Date: Fri, 24 Jan 2020 00:29:13 -0600
Subject: [PATCH] Ability to stream responses for faster updates.

---
 src/filters.rs  |  14 +++-
 src/handlers.rs | 106 +++++++++++++++++------
 src/main.rs     |   2 +
 src/models.rs   | 217 ++++++++++++++++++++++++++----------------------
 4 files changed, 217 insertions(+), 122 deletions(-)

diff --git a/src/filters.rs b/src/filters.rs
index 3ce48bf..8b52041 100644
--- a/src/filters.rs
+++ b/src/filters.rs
@@ -6,7 +6,8 @@ use warp::{Filter, Rejection, Reply};
 pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     search_file(db.clone())
         .or(search_image(db.clone()))
-        .or(search_hashes(db))
+        .or(search_hashes(db.clone()))
+        .or(stream_search_image(db))
 }
 
 pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@@ -37,6 +38,17 @@ pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
         .and_then(handlers::search_hashes)
 }
 
+pub fn stream_search_image(
+    db: Pool,
+) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
+    warp::path("stream")
+        .and(warp::post())
+        .and(warp::multipart::form().max_length(1024 * 1024 * 10))
+        .and(with_pool(db))
+        .and(with_api_key())
+        .and_then(handlers::stream_image)
+}
+
 fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
     warp::header::<String>("x-api-key")
 }
diff --git a/src/handlers.rs b/src/handlers.rs
index 84d653a..449e322 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -1,6 +1,5 @@
-use crate::models::image_query;
+use crate::models::{image_query, image_query_sync};
 use crate::types::*;
-use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
 use crate::{rate_limit, Pool};
 use log::{debug, info};
 use warp::{reject, Rejection, Reply};
@@ -39,10 +38,10 @@ impl warp::reject::Reject for Error {}
 
 pub async fn search_image(
     form: warp::multipart::FormData,
     opts: ImageSearchOpts,
-    db: Pool,
+    pool: Pool,
     api_key: String,
 ) -> Result<impl Reply, Rejection> {
-    let db = db.get().await.map_err(map_bb8_err)?;
+    let db = pool.get().await.map_err(map_bb8_err)?;
 
     rate_limit!(&api_key, &db, image_limit, "image");
@@ -79,28 +78,25 @@
 
     debug!("Matching hash {}", num);
 
-    let results = {
+    let mut items = {
         if opts.search_type == Some(ImageSearchType::Force) {
-            image_query(&db, vec![num], 10).await.unwrap()
+            image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
+                .await
+                .unwrap()
         } else {
-            let results = image_query(&db, vec![num], 0).await.unwrap();
+            let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
+                .await
+                .unwrap();
             if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
-                image_query(&db, vec![num], 10).await.unwrap()
+                image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
+                    .await
+                    .unwrap()
             } else {
                 results
             }
         }
     };
 
-    let mut items = Vec::with_capacity(results.len());
-
-    items.extend(extract_fa_rows(results.furaffinity, Some(&hash.as_bytes())));
-    items.extend(extract_e621_rows(results.e621, Some(&hash.as_bytes())));
-    items.extend(extract_twitter_rows(
-        results.twitter,
-        Some(&hash.as_bytes()),
-    ));
-
     items.sort_by(|a, b| {
         a.distance
             .unwrap_or(u64::max_value())
@@ -116,11 +112,75 @@
     Ok(warp::reply::json(&similarity))
 }
 
+pub async fn stream_image(
+    form: warp::multipart::FormData,
+    pool: Pool,
+    api_key: String,
+) -> Result<impl Reply, Rejection> {
+    let db = pool.get().await.map_err(map_bb8_err)?;
+
+    rate_limit!(&api_key, &db, image_limit, "image", 2);
+
+    use bytes::BufMut;
+    use futures_util::StreamExt;
+    let parts: Vec<_> = form.collect().await;
+    let mut parts = parts
+        .into_iter()
+        .map(|part| {
+            let part = part.unwrap();
+            (part.name().to_string(), part)
+        })
+        .collect::<std::collections::HashMap<_, _>>();
+    let image = parts.remove("image").unwrap();
+
+    let bytes = image
+        .stream()
+        .fold(bytes::BytesMut::new(), |mut b, data| {
+            b.put(data.unwrap());
+            async move { b }
+        })
+        .await;
+
+    let hash = {
+        let hasher = crate::get_hasher();
+        let image = image::load_from_memory(&bytes).unwrap();
+        hasher.hash_image(&image)
+    };
+
+    let mut buf: [u8; 8] = [0; 8];
+    buf.copy_from_slice(&hash.as_bytes());
+
+    let num = i64::from_be_bytes(buf);
+
+    debug!("Stream matching hash {}", num);
+
+    let exact_event_stream =
+        image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
+            .map(sse_matches);
+
+    let close_event_stream =
+        image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
+            .map(sse_matches);
+
+    let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
+
+    Ok(warp::sse::reply(event_stream))
+}
+
+fn sse_matches(
+    matches: Result<Vec<File>, tokio_postgres::Error>,
+) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
+    let items = matches.unwrap();
+
+    Ok(warp::sse::json(items))
+}
+
 pub async fn search_hashes(
     opts: HashSearchOpts,
     db: Pool,
     api_key: String,
 ) -> Result<impl Reply, Rejection> {
+    let pool = db.clone();
     let db = db.get().await.map_err(map_bb8_err)?;
 
     let hashes: Vec<i64> = opts
@@ -136,14 +196,12 @@
 
     rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
 
-    let results = image_query(&db, hashes, 10)
-        .await
-        .map_err(|err| reject::custom(Error::from(err)))?;
+    let mut results = image_query_sync(pool, hashes.clone(), 10, None);
+    let mut matches = Vec::new();
 
-    let mut matches = Vec::with_capacity(results.len());
-    matches.extend(extract_fa_rows(results.furaffinity, None));
-    matches.extend(extract_e621_rows(results.e621, None));
-    matches.extend(extract_twitter_rows(results.twitter, None));
+    while let Some(r) = results.recv().await {
+        matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
+    }
 
     Ok(warp::reply::json(&matches))
 }
diff --git a/src/main.rs b/src/main.rs
index ab9c7f8..00a3e1c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,3 +1,5 @@
+#![recursion_limit = "256"]
+
 use std::str::FromStr;
 
 mod filters;
diff --git a/src/models.rs b/src/models.rs
index d3c8d59..7506418 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,4 +1,6 @@
 use crate::types::*;
+use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
+use crate::Pool;
 
 pub type DB<'a> =
     &'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
@@ -35,107 +37,128 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
     }
 }
 
-pub struct ImageQueryResults {
-    pub furaffinity: Vec<tokio_postgres::Row>,
-    pub e621: Vec<tokio_postgres::Row>,
-    pub twitter: Vec<tokio_postgres::Row>,
-}
-
-impl ImageQueryResults {
-    #[inline]
-    pub fn len(&self) -> usize {
-        self.furaffinity.len() + self.e621.len() + self.twitter.len()
-    }
-
-    #[inline]
-    pub fn is_empty(&self) -> bool {
-        self.len() == 0
-    }
-}
-
 pub async fn image_query(
-    db: DB<'_>,
+    pool: Pool,
     hashes: Vec<i64>,
     distance: i64,
-) -> Result<ImageQueryResults, tokio_postgres::Error> {
-    let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
-        Vec::with_capacity(hashes.len() + 1);
-    params.insert(0, &distance);
+    hash: Option<Vec<u8>>,
+) -> Result<Vec<File>, tokio_postgres::Error> {
+    let mut results = image_query_sync(pool, hashes, distance, hash);
+    let mut matches = Vec::new();
 
-    let mut fa_where_clause = Vec::with_capacity(hashes.len());
-    let mut hash_where_clause = Vec::with_capacity(hashes.len());
-
-    for (idx, hash) in hashes.iter().enumerate() {
-        params.push(hash);
-
-        fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
-        hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
+    while let Some(r) = results.recv().await {
+        matches.extend(r?);
     }
-    let hash_where_clause = hash_where_clause.join(" OR ");
 
-    let fa_query = format!(
-        "SELECT
-            submission.id,
-            submission.url,
-            submission.filename,
-            submission.file_id,
-            submission.hash,
-            submission.hash_int,
-            artist.name
-        FROM
-            submission
-        JOIN artist
-            ON artist.id = submission.artist_id
-        WHERE
-            {}",
-        fa_where_clause.join(" OR ")
-    );
-
-    let e621_query = format!(
-        "SELECT
-            e621.id,
-            e621.hash,
-            e621.data->>'file_url' url,
-            e621.data->>'md5' md5,
-            sources.list sources,
-            artists.list artists,
-            (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
-        FROM
-            e621,
-            LATERAL (
-                SELECT array_agg(s) list
-                FROM jsonb_array_elements_text(data->'sources') s
-            ) sources,
-            LATERAL (
-                SELECT array_agg(s) list
-                FROM jsonb_array_elements_text(data->'artist') s
-            ) artists
-        WHERE
-            {}",
-        &hash_where_clause
-    );
-
-    let twitter_query = format!(
-        "SELECT
-            twitter_view.id,
-            twitter_view.artists,
-            twitter_view.url,
-            twitter_view.hash
-        FROM
-            twitter_view
-        WHERE
-            {}",
-        &hash_where_clause
-    );
-
-    let furaffinity = db.query::<str>(&*fa_query, &params);
-    let e621 = db.query::<str>(&*e621_query, &params);
-    let twitter = db.query::<str>(&*twitter_query, &params);
-
-    let results = futures::future::join3(furaffinity, e621, twitter).await;
-    Ok(ImageQueryResults {
-        furaffinity: results.0?,
-        e621: results.1?,
-        twitter: results.2?,
-    })
+    Ok(matches)
+}
+
+pub fn image_query_sync(
+    pool: Pool,
+    hashes: Vec<i64>,
+    distance: i64,
+    hash: Option<Vec<u8>>,
+) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
+    use futures_util::FutureExt;
+
+    let (mut tx, rx) = tokio::sync::mpsc::channel(3);
+
+    tokio::spawn(async move {
+        let db = pool.get().await.unwrap();
+
+        let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
+            Vec::with_capacity(hashes.len() + 1);
+        params.insert(0, &distance);
+
+        let mut fa_where_clause = Vec::with_capacity(hashes.len());
+        let mut hash_where_clause = Vec::with_capacity(hashes.len());
+
+        for (idx, hash) in hashes.iter().enumerate() {
+            params.push(hash);
+
+            fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
+            hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
+        }
+        let hash_where_clause = hash_where_clause.join(" OR ");
+
+        let fa_query = format!(
+            "SELECT
+                submission.id,
+                submission.url,
+                submission.filename,
+                submission.file_id,
+                submission.hash,
+                submission.hash_int,
+                artist.name
+            FROM
+                submission
+            JOIN artist
+                ON artist.id = submission.artist_id
+            WHERE
+                {}",
+            fa_where_clause.join(" OR ")
+        );
+
+        let e621_query = format!(
+            "SELECT
+                e621.id,
+                e621.hash,
+                e621.data->>'file_url' url,
+                e621.data->>'md5' md5,
+                sources.list sources,
+                artists.list artists,
+                (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
+            FROM
+                e621,
+                LATERAL (
+                    SELECT array_agg(s) list
+                    FROM jsonb_array_elements_text(data->'sources') s
+                ) sources,
+                LATERAL (
+                    SELECT array_agg(s) list
+                    FROM jsonb_array_elements_text(data->'artist') s
+                ) artists
+            WHERE
+                {}",
+            &hash_where_clause
+        );
+
+        let twitter_query = format!(
+            "SELECT
+                twitter_view.id,
+                twitter_view.artists,
+                twitter_view.url,
+                twitter_view.hash
+            FROM
+                twitter_view
+            WHERE
+                {}",
+            &hash_where_clause
+        );
+
+        let mut furaffinity = Box::pin(db.query::<str>(&*fa_query, &params).fuse());
+        let mut e621 = Box::pin(db.query::<str>(&*e621_query, &params).fuse());
+        let mut twitter = Box::pin(db.query::<str>(&*twitter_query, &params).fuse());
+
+        #[allow(clippy::unnecessary_mut_passed)]
+        loop {
+            futures::select! {
+                fa = furaffinity => {
+                    let rows = fa.map(|rows| extract_fa_rows(rows, hash.as_deref()).into_iter().collect());
+                    tx.send(rows).await.unwrap();
+                }
+                e = e621 => {
+                    let rows = e.map(|rows| extract_e621_rows(rows, hash.as_deref()).into_iter().collect());
+                    tx.send(rows).await.unwrap();
+                }
+                t = twitter => {
+                    let rows = t.map(|rows| extract_twitter_rows(rows, hash.as_deref()).into_iter().collect());
+                    tx.send(rows).await.unwrap();
+                }
+                complete => break,
+            }
+        }
+    });
+
+    rx
 }