From cfd05672dd55b7aafb487f068068a1191809a765 Mon Sep 17 00:00:00 2001 From: Syfaro Date: Tue, 21 Jan 2020 19:04:28 -0600 Subject: [PATCH] Ability to search by hashes. --- src/filters.rs | 13 ++++++- src/handlers.rs | 36 +++++++++++++++++-- src/models.rs | 91 ++++++++++++++++++++++++++++--------------------- src/types.rs | 5 +++ 4 files changed, 103 insertions(+), 42 deletions(-) diff --git a/src/filters.rs b/src/filters.rs index 25f6ca7..3ce48bf 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -4,7 +4,9 @@ use std::convert::Infallible; use warp::{Filter, Rejection, Reply}; pub fn search(db: Pool) -> impl Filter + Clone { - search_file(db.clone()).or(search_image(db)) + search_file(db.clone()) + .or(search_image(db.clone())) + .or(search_hashes(db)) } pub fn search_file(db: Pool) -> impl Filter + Clone { @@ -26,6 +28,15 @@ pub fn search_image(db: Pool) -> impl Filter impl Filter + Clone { + warp::path("hashes") + .and(warp::get()) + .and(warp::query::()) + .and(with_pool(db)) + .and(with_api_key()) + .and_then(handlers::search_hashes) +} + fn with_api_key() -> impl Filter + Clone { warp::header::("x-api-key") } diff --git a/src/handlers.rs b/src/handlers.rs index c543ac3..3dc66cf 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -81,13 +81,13 @@ pub async fn search_image( let (fa_results, e621_results) = { if opts.search_type == Some(ImageSearchType::Force) { - image_query(&db, num, 10).await.unwrap() + image_query(&db, vec![num], 10).await.unwrap() } else { - let (fa_results, e621_results) = image_query(&db, num, 0).await.unwrap(); + let (fa_results, e621_results) = image_query(&db, vec![num], 0).await.unwrap(); if fa_results.len() + e621_results.len() == 0 && opts.search_type != Some(ImageSearchType::Exact) { - image_query(&db, num, 10).await.unwrap() + image_query(&db, vec![num], 10).await.unwrap() } else { (fa_results, e621_results) } @@ -114,6 +114,36 @@ pub async fn search_image( Ok(warp::reply::json(&similarity)) } +pub async fn search_hashes( + opts: HashSearchOpts, + db: Pool, + api_key: String, +) -> Result { + let db = db.get().await.map_err(map_bb8_err)?; + + let hashes: Vec = opts + .hashes + .split(',') + .filter_map(|hash| hash.parse::().ok()) + .collect(); + + if hashes.is_empty() { + return Err(warp::reject::custom(Error::InvalidData)); + } + + rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); + + let (fa_matches, e621_matches) = image_query(&db, hashes, 10) + .await + .map_err(|err| reject::custom(Error::from(err)))?; + + let mut matches = Vec::with_capacity(fa_matches.len() + e621_matches.len()); + matches.extend(extract_fa_rows(fa_matches, None)); + matches.extend(extract_e621_rows(e621_matches, None)); + + Ok(warp::reply::json(&matches)) +} + pub async fn search_file( opts: FileSearchOpts, db: Pool, diff --git a/src/models.rs b/src/models.rs index 1e260f3..ed90012 100644 --- a/src/models.rs +++ b/src/models.rs @@ -37,53 +37,68 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option { pub async fn image_query( db: DB<'_>, - num: i64, + hashes: Vec, distance: i64, ) -> Result<(Vec, Vec), tokio_postgres::Error> { - let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&num, &distance]; + let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = + Vec::with_capacity(hashes.len() + 1); + params.insert(0, &distance); - let fa = db.query( + let mut fa_where_clause = Vec::with_capacity(hashes.len()); + let mut e621_where_clause = Vec::with_capacity(hashes.len()); + + for (idx, hash) in hashes.iter().enumerate() { + params.push(hash); + + fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2)); + e621_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2)); + } + + let fa_query = format!( "SELECT - submission.id, - submission.url, - submission.filename, - submission.file_id, - submission.hash, - submission.hash_int, - artist.name - FROM - submission - JOIN artist - ON artist.id = submission.artist_id - WHERE - hash_int <@ ($1, $2)", - ¶ms, + submission.id, + submission.url, + submission.filename, + submission.file_id, + submission.hash, + submission.hash_int, + artist.name + FROM + submission + JOIN artist + ON artist.id = submission.artist_id + WHERE + {}", + fa_where_clause.join(" OR ") ); - let e621 = db.query( + let e621_query = format!( "SELECT - e621.id, - e621.hash, - e621.data->>'file_url' url, - e621.data->>'md5' md5, - sources.list sources, - artists.list artists, - (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename - FROM - e621, - LATERAL ( - SELECT array_agg(s) list - FROM jsonb_array_elements_text(data->'sources') s - ) sources, - LATERAL ( - SELECT array_agg(s) list - FROM jsonb_array_elements_text(data->'artist') s - ) artists - WHERE - hash <@ ($1, $2)", - ¶ms, + e621.id, + e621.hash, + e621.data->>'file_url' url, + e621.data->>'md5' md5, + sources.list sources, + artists.list artists, + (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename + FROM + e621, + LATERAL ( + SELECT array_agg(s) list + FROM jsonb_array_elements_text(data->'sources') s + ) sources, + LATERAL ( + SELECT array_agg(s) list + FROM jsonb_array_elements_text(data->'artist') s + ) artists + WHERE + {}", + e621_where_clause.join(" OR ") ); + let fa = db.query::(&*fa_query, ¶ms); + let e621 = db.query::(&*e621_query, ¶ms); + let results = futures::future::join(fa, e621).await; Ok((results.0?, results.1?)) } diff --git a/src/types.rs b/src/types.rs index 1aace91..d339d01 100644 --- a/src/types.rs +++ b/src/types.rs @@ -93,3 +93,8 @@ pub struct ErrorMessage { pub code: u16, pub message: String, } + +#[derive(Deserialize)] +pub struct HashSearchOpts { + pub hashes: String, +}