Ability to search by hashes.

This commit is contained in:
Syfaro 2020-01-21 19:04:28 -06:00
parent 658a023334
commit cfd05672dd
4 changed files with 103 additions and 42 deletions

View File

@ -4,7 +4,9 @@ use std::convert::Infallible;
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_file(db.clone()).or(search_image(db)) search_file(db.clone())
.or(search_image(db.clone()))
.or(search_hashes(db))
} }
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@ -26,6 +28,15 @@ pub fn search_image(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejec
.and_then(handlers::search_image) .and_then(handlers::search_image)
} }
pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("hashes")
.and(warp::get())
.and(warp::query::<HashSearchOpts>())
.and(with_pool(db))
.and(with_api_key())
.and_then(handlers::search_hashes)
}
fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone { fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
warp::header::<String>("x-api-key") warp::header::<String>("x-api-key")
} }

View File

@ -81,13 +81,13 @@ pub async fn search_image(
let (fa_results, e621_results) = { let (fa_results, e621_results) = {
if opts.search_type == Some(ImageSearchType::Force) { if opts.search_type == Some(ImageSearchType::Force) {
image_query(&db, num, 10).await.unwrap() image_query(&db, vec![num], 10).await.unwrap()
} else { } else {
let (fa_results, e621_results) = image_query(&db, num, 0).await.unwrap(); let (fa_results, e621_results) = image_query(&db, vec![num], 0).await.unwrap();
if fa_results.len() + e621_results.len() == 0 if fa_results.len() + e621_results.len() == 0
&& opts.search_type != Some(ImageSearchType::Exact) && opts.search_type != Some(ImageSearchType::Exact)
{ {
image_query(&db, num, 10).await.unwrap() image_query(&db, vec![num], 10).await.unwrap()
} else { } else {
(fa_results, e621_results) (fa_results, e621_results)
} }
@ -114,6 +114,36 @@ pub async fn search_image(
Ok(warp::reply::json(&similarity)) Ok(warp::reply::json(&similarity))
} }
pub async fn search_hashes(
opts: HashSearchOpts,
db: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
let hashes: Vec<i64> = opts
.hashes
.split(',')
.filter_map(|hash| hash.parse::<i64>().ok())
.collect();
if hashes.is_empty() {
return Err(warp::reject::custom(Error::InvalidData));
}
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let (fa_matches, e621_matches) = image_query(&db, hashes, 10)
.await
.map_err(|err| reject::custom(Error::from(err)))?;
let mut matches = Vec::with_capacity(fa_matches.len() + e621_matches.len());
matches.extend(extract_fa_rows(fa_matches, None));
matches.extend(extract_e621_rows(e621_matches, None));
Ok(warp::reply::json(&matches))
}
pub async fn search_file( pub async fn search_file(
opts: FileSearchOpts, opts: FileSearchOpts,
db: Pool, db: Pool,

View File

@ -37,12 +37,24 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
pub async fn image_query( pub async fn image_query(
db: DB<'_>, db: DB<'_>,
num: i64, hashes: Vec<i64>,
distance: i64, distance: i64,
) -> Result<(Vec<tokio_postgres::Row>, Vec<tokio_postgres::Row>), tokio_postgres::Error> { ) -> Result<(Vec<tokio_postgres::Row>, Vec<tokio_postgres::Row>), tokio_postgres::Error> {
let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&num, &distance]; let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
Vec::with_capacity(hashes.len() + 1);
params.insert(0, &distance);
let fa = db.query( let mut fa_where_clause = Vec::with_capacity(hashes.len());
let mut e621_where_clause = Vec::with_capacity(hashes.len());
for (idx, hash) in hashes.iter().enumerate() {
params.push(hash);
fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2));
e621_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2));
}
let fa_query = format!(
"SELECT "SELECT
submission.id, submission.id,
submission.url, submission.url,
@ -56,11 +68,11 @@ pub async fn image_query(
JOIN artist JOIN artist
ON artist.id = submission.artist_id ON artist.id = submission.artist_id
WHERE WHERE
hash_int <@ ($1, $2)", {}",
&params, fa_where_clause.join(" OR ")
); );
let e621 = db.query( let e621_query = format!(
"SELECT "SELECT
e621.id, e621.id,
e621.hash, e621.hash,
@ -80,10 +92,13 @@ pub async fn image_query(
FROM jsonb_array_elements_text(data->'artist') s FROM jsonb_array_elements_text(data->'artist') s
) artists ) artists
WHERE WHERE
hash <@ ($1, $2)", {}",
&params, e621_where_clause.join(" OR ")
); );
let fa = db.query::<str>(&*fa_query, &params);
let e621 = db.query::<str>(&*e621_query, &params);
let results = futures::future::join(fa, e621).await; let results = futures::future::join(fa, e621).await;
Ok((results.0?, results.1?)) Ok((results.0?, results.1?))
} }

View File

@ -93,3 +93,8 @@ pub struct ErrorMessage {
pub code: u16, pub code: u16,
pub message: String, pub message: String,
} }
#[derive(Deserialize)]
pub struct HashSearchOpts {
pub hashes: String,
}