Ability to stream responses for faster updates.

This commit is contained in:
Syfaro 2020-01-24 00:29:13 -06:00
parent cad0016522
commit be5e1a9b97
4 changed files with 217 additions and 122 deletions

View File

@ -6,7 +6,8 @@ use warp::{Filter, Rejection, Reply};
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_file(db.clone())
.or(search_image(db.clone()))
.or(search_hashes(db))
.or(search_hashes(db.clone()))
.or(stream_search_image(db))
}
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@ -37,6 +38,17 @@ pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Reje
.and_then(handlers::search_hashes)
}
pub fn stream_search_image(
db: Pool,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("stream")
.and(warp::post())
.and(warp::multipart::form().max_length(1024 * 1024 * 10))
.and(with_pool(db))
.and(with_api_key())
.and_then(handlers::stream_image)
}
fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
warp::header::<String>("x-api-key")
}

View File

@ -1,6 +1,5 @@
use crate::models::image_query;
use crate::models::{image_query, image_query_sync};
use crate::types::*;
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
use crate::{rate_limit, Pool};
use log::{debug, info};
use warp::{reject, Rejection, Reply};
@ -39,10 +38,10 @@ impl warp::reject::Reject for Error {}
pub async fn search_image(
form: warp::multipart::FormData,
opts: ImageSearchOpts,
db: Pool,
pool: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image");
@ -79,28 +78,25 @@ pub async fn search_image(
debug!("Matching hash {}", num);
let results = {
let mut items = {
if opts.search_type == Some(ImageSearchType::Force) {
image_query(&db, vec![num], 10).await.unwrap()
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.await
.unwrap()
} else {
let results = image_query(&db, vec![num], 0).await.unwrap();
let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
.await
.unwrap();
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
image_query(&db, vec![num], 10).await.unwrap()
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.await
.unwrap()
} else {
results
}
}
};
let mut items = Vec::with_capacity(results.len());
items.extend(extract_fa_rows(results.furaffinity, Some(&hash.as_bytes())));
items.extend(extract_e621_rows(results.e621, Some(&hash.as_bytes())));
items.extend(extract_twitter_rows(
results.twitter,
Some(&hash.as_bytes()),
));
items.sort_by(|a, b| {
a.distance
.unwrap_or(u64::max_value())
@ -116,11 +112,75 @@ pub async fn search_image(
Ok(warp::reply::json(&similarity))
}
pub async fn stream_image(
form: warp::multipart::FormData,
pool: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image", 2);
use bytes::BufMut;
use futures_util::StreamExt;
let parts: Vec<_> = form.collect().await;
let mut parts = parts
.into_iter()
.map(|part| {
let part = part.unwrap();
(part.name().to_string(), part)
})
.collect::<std::collections::HashMap<_, _>>();
let image = parts.remove("image").unwrap();
let bytes = image
.stream()
.fold(bytes::BytesMut::new(), |mut b, data| {
b.put(data.unwrap());
async move { b }
})
.await;
let hash = {
let hasher = crate::get_hasher();
let image = image::load_from_memory(&bytes).unwrap();
hasher.hash_image(&image)
};
let mut buf: [u8; 8] = [0; 8];
buf.copy_from_slice(&hash.as_bytes());
let num = i64::from_be_bytes(buf);
debug!("Stream matching hash {}", num);
let exact_event_stream =
image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
.map(sse_matches);
let close_event_stream =
image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.map(sse_matches);
let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
Ok(warp::sse::reply(event_stream))
}
fn sse_matches(
matches: Result<Vec<File>, tokio_postgres::Error>,
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
let items = matches.unwrap();
Ok(warp::sse::json(items))
}
pub async fn search_hashes(
opts: HashSearchOpts,
db: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let pool = db.clone();
let db = db.get().await.map_err(map_bb8_err)?;
let hashes: Vec<i64> = opts
@ -136,14 +196,12 @@ pub async fn search_hashes(
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let results = image_query(&db, hashes, 10)
.await
.map_err(|err| reject::custom(Error::from(err)))?;
let mut results = image_query_sync(pool, hashes.clone(), 10, None);
let mut matches = Vec::new();
let mut matches = Vec::with_capacity(results.len());
matches.extend(extract_fa_rows(results.furaffinity, None));
matches.extend(extract_e621_rows(results.e621, None));
matches.extend(extract_twitter_rows(results.twitter, None));
while let Some(r) = results.recv().await {
matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
}
Ok(warp::reply::json(&matches))
}

View File

@ -1,3 +1,5 @@
#![recursion_limit = "256"]
use std::str::FromStr;
mod filters;

View File

@ -1,4 +1,6 @@
use crate::types::*;
use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows};
use crate::Pool;
pub type DB<'a> =
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
@ -35,107 +37,128 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
}
}
pub struct ImageQueryResults {
pub furaffinity: Vec<tokio_postgres::Row>,
pub e621: Vec<tokio_postgres::Row>,
pub twitter: Vec<tokio_postgres::Row>,
}
impl ImageQueryResults {
#[inline]
pub fn len(&self) -> usize {
self.furaffinity.len() + self.e621.len() + self.twitter.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
pub async fn image_query(
db: DB<'_>,
pool: Pool,
hashes: Vec<i64>,
distance: i64,
) -> Result<ImageQueryResults, tokio_postgres::Error> {
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
Vec::with_capacity(hashes.len() + 1);
params.insert(0, &distance);
hash: Option<Vec<u8>>,
) -> Result<Vec<File>, tokio_postgres::Error> {
let mut results = image_query_sync(pool, hashes, distance, hash);
let mut matches = Vec::new();
let mut fa_where_clause = Vec::with_capacity(hashes.len());
let mut hash_where_clause = Vec::with_capacity(hashes.len());
for (idx, hash) in hashes.iter().enumerate() {
params.push(hash);
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
while let Some(r) = results.recv().await {
matches.extend(r?);
}
let hash_where_clause = hash_where_clause.join(" OR ");
let fa_query = format!(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
submission.hash,
submission.hash_int,
artist.name
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
WHERE
{}",
fa_where_clause.join(" OR ")
);
let e621_query = format!(
"SELECT
e621.id,
e621.hash,
e621.data->>'file_url' url,
e621.data->>'md5' md5,
sources.list sources,
artists.list artists,
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
FROM
e621,
LATERAL (
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'sources') s
) sources,
LATERAL (
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'artist') s
) artists
WHERE
{}",
&hash_where_clause
);
let twitter_query = format!(
"SELECT
twitter_view.id,
twitter_view.artists,
twitter_view.url,
twitter_view.hash
FROM
twitter_view
WHERE
{}",
&hash_where_clause
);
let furaffinity = db.query::<str>(&*fa_query, &params);
let e621 = db.query::<str>(&*e621_query, &params);
let twitter = db.query::<str>(&*twitter_query, &params);
let results = futures::future::join3(furaffinity, e621, twitter).await;
Ok(ImageQueryResults {
furaffinity: results.0?,
e621: results.1?,
twitter: results.2?,
})
Ok(matches)
}
pub fn image_query_sync(
pool: Pool,
hashes: Vec<i64>,
distance: i64,
hash: Option<Vec<u8>>,
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
use futures_util::FutureExt;
let (mut tx, rx) = tokio::sync::mpsc::channel(3);
tokio::spawn(async move {
let db = pool.get().await.unwrap();
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
Vec::with_capacity(hashes.len() + 1);
params.insert(0, &distance);
let mut fa_where_clause = Vec::with_capacity(hashes.len());
let mut hash_where_clause = Vec::with_capacity(hashes.len());
for (idx, hash) in hashes.iter().enumerate() {
params.push(hash);
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
}
let hash_where_clause = hash_where_clause.join(" OR ");
let fa_query = format!(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
submission.hash,
submission.hash_int,
artist.name
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
WHERE
{}",
fa_where_clause.join(" OR ")
);
let e621_query = format!(
"SELECT
e621.id,
e621.hash,
e621.data->>'file_url' url,
e621.data->>'md5' md5,
sources.list sources,
artists.list artists,
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
FROM
e621,
LATERAL (
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'sources') s
) sources,
LATERAL (
SELECT array_agg(s) list
FROM jsonb_array_elements_text(data->'artist') s
) artists
WHERE
{}",
&hash_where_clause
);
let twitter_query = format!(
"SELECT
twitter_view.id,
twitter_view.artists,
twitter_view.url,
twitter_view.hash
FROM
twitter_view
WHERE
{}",
&hash_where_clause
);
let mut furaffinity = Box::pin(db.query::<str>(&*fa_query, &params).fuse());
let mut e621 = Box::pin(db.query::<str>(&*e621_query, &params).fuse());
let mut twitter = Box::pin(db.query::<str>(&*twitter_query, &params).fuse());
#[allow(clippy::unnecessary_mut_passed)]
loop {
futures::select! {
fa = furaffinity => {
let rows = fa.map(|rows| extract_fa_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
e = e621 => {
let rows = e.map(|rows| extract_e621_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
t = twitter => {
let rows = t.map(|rows| extract_twitter_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap();
}
complete => break,
}
}
});
rx
}