diff --git a/src/filters.rs b/src/filters.rs index b81221e..25f6ca7 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -20,6 +20,7 @@ pub fn search_image(db: Pool) -> impl Filter()) .and(with_pool(db)) .and(with_api_key()) .and_then(handlers::search_image) diff --git a/src/handlers.rs b/src/handlers.rs index 217196f..c543ac3 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,7 +1,8 @@ +use crate::models::image_query; use crate::types::*; use crate::utils::{extract_e621_rows, extract_fa_rows}; use crate::{rate_limit, Pool}; -use log::{info, debug}; +use log::{debug, info}; use warp::{reject, Rejection, Reply}; fn map_bb8_err(err: bb8::RunError) -> Rejection { @@ -37,6 +38,7 @@ impl warp::reject::Reject for Error {} pub async fn search_image( form: warp::multipart::FormData, + opts: ImageSearchOpts, db: Pool, api_key: String, ) -> Result { @@ -77,52 +79,20 @@ pub async fn search_image( debug!("Matching hash {}", num); - let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&num]; - - let fa = db.query( - "SELECT - submission.id, - submission.url, - submission.filename, - submission.file_id, - submission.hash, - submission.hash_int, - artist.name - FROM - submission - JOIN artist - ON artist.id = submission.artist_id - WHERE - hash_int <@ ($1, 10)", - ¶ms, - ); - - let e621 = db.query( - "SELECT - e621.id, - e621.hash, - e621.data->>'file_url' url, - e621.data->>'md5' md5, - sources.list sources, - artists.list artists, - (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename - FROM - e621, - LATERAL ( - SELECT array_agg(s) list - FROM jsonb_array_elements_text(data->'sources') s - ) sources, - LATERAL ( - SELECT array_agg(s) list - FROM jsonb_array_elements_text(data->'artist') s - ) artists - WHERE - hash <@ ($1, 10)", - ¶ms, - ); - - let results = futures::future::join(fa, e621).await; - let (fa_results, e621_results) = (results.0.unwrap(), results.1.unwrap()); + let (fa_results, e621_results) = { + if opts.search_type == Some(ImageSearchType::Force) { + image_query(&db, num, 10).await.unwrap() + } else { + let (fa_results, e621_results) = image_query(&db, num, 0).await.unwrap(); + if fa_results.len() + e621_results.len() == 0 + && opts.search_type != Some(ImageSearchType::Exact) + { + image_query(&db, num, 10).await.unwrap() + } else { + (fa_results, e621_results) + } + } + }; let mut items = Vec::with_capacity(fa_results.len() + e621_results.len()); diff --git a/src/main.rs b/src/main.rs index 5a3ba44..ab9c7f8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,9 +27,12 @@ async fn main() { let log = warp::log("fuzzysearch"); let cors = warp::cors() .allow_any_origin() + .allow_headers(vec!["x-api-key"]) .allow_methods(vec!["GET", "POST"]); - let api = filters::search(db_pool); + let options = warp::options().map(|| "✓"); + + let api = options.or(filters::search(db_pool)); let routes = api .or(warp::path::end() .map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net")))) diff --git a/src/models.rs b/src/models.rs index 1b4aa1f..1e260f3 100644 --- a/src/models.rs +++ b/src/models.rs @@ -34,3 +34,56 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option { _ => None, } } + +pub async fn image_query( + db: DB<'_>, + num: i64, + distance: i64, +) -> Result<(Vec, Vec), tokio_postgres::Error> { + let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&num, &distance]; + + let fa = db.query( + "SELECT + submission.id, + submission.url, + submission.filename, + submission.file_id, + submission.hash, + submission.hash_int, + artist.name + FROM + submission + JOIN artist + ON artist.id = submission.artist_id + WHERE + hash_int <@ ($1, $2)", + ¶ms, + ); + + let e621 = db.query( + "SELECT + e621.id, + e621.hash, + e621.data->>'file_url' url, + e621.data->>'md5' md5, + sources.list sources, + artists.list artists, + (e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename + FROM + e621, + LATERAL ( + SELECT array_agg(s) list + FROM jsonb_array_elements_text(data->'sources') s + ) sources, + LATERAL ( + SELECT array_agg(s) list + FROM jsonb_array_elements_text(data->'artist') s + ) artists + WHERE + hash <@ ($1, $2)", + ¶ms, + ); + + let results = futures::future::join(fa, e621).await; + Ok((results.0?, results.1?)) +} diff --git a/src/types.rs b/src/types.rs index bae316f..1aace91 100644 --- a/src/types.rs +++ b/src/types.rs @@ -68,6 +68,20 @@ pub struct FileSearchOpts { pub url: Option, } +#[derive(Debug, Deserialize)] +pub struct ImageSearchOpts { + #[serde(rename = "type")] + pub search_type: Option, +} + +#[derive(Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageSearchType { + Close, + Exact, + Force, +} + #[derive(Debug, Serialize)] pub struct ImageSimilarity { pub hash: i64,