diff --git a/src/handlers.rs b/src/handlers.rs index 3dc66cf..84d653a 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,6 +1,6 @@ use crate::models::image_query; use crate::types::*; -use crate::utils::{extract_e621_rows, extract_fa_rows}; +use crate::utils::{extract_e621_rows, extract_fa_rows, extract_twitter_rows}; use crate::{rate_limit, Pool}; use log::{debug, info}; use warp::{reject, Rejection, Reply}; @@ -79,25 +79,27 @@ pub async fn search_image( debug!("Matching hash {}", num); - let (fa_results, e621_results) = { + let results = { if opts.search_type == Some(ImageSearchType::Force) { image_query(&db, vec![num], 10).await.unwrap() } else { - let (fa_results, e621_results) = image_query(&db, vec![num], 0).await.unwrap(); - if fa_results.len() + e621_results.len() == 0 - && opts.search_type != Some(ImageSearchType::Exact) - { + let results = image_query(&db, vec![num], 0).await.unwrap(); + if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) { image_query(&db, vec![num], 10).await.unwrap() } else { - (fa_results, e621_results) + results } } }; - let mut items = Vec::with_capacity(fa_results.len() + e621_results.len()); + let mut items = Vec::with_capacity(results.len()); - items.extend(extract_fa_rows(fa_results, Some(&hash.as_bytes()))); - items.extend(extract_e621_rows(e621_results, Some(&hash.as_bytes()))); + items.extend(extract_fa_rows(results.furaffinity, Some(&hash.as_bytes()))); + items.extend(extract_e621_rows(results.e621, Some(&hash.as_bytes()))); + items.extend(extract_twitter_rows( + results.twitter, + Some(&hash.as_bytes()), + )); items.sort_by(|a, b| { a.distance @@ -124,6 +126,7 @@ pub async fn search_hashes( let hashes: Vec = opts .hashes .split(',') + .take(10) .filter_map(|hash| hash.parse::().ok()) .collect(); @@ -133,13 +136,14 @@ pub async fn search_hashes( rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); - let (fa_matches, e621_matches) = image_query(&db, hashes, 10) + let results = image_query(&db, hashes, 10) .await .map_err(|err| reject::custom(Error::from(err)))?; - let mut matches = Vec::with_capacity(fa_matches.len() + e621_matches.len()); - matches.extend(extract_fa_rows(fa_matches, None)); - matches.extend(extract_e621_rows(e621_matches, None)); + let mut matches = Vec::with_capacity(results.len()); + matches.extend(extract_fa_rows(results.furaffinity, None)); + matches.extend(extract_e621_rows(results.e621, None)); + matches.extend(extract_twitter_rows(results.twitter, None)); Ok(warp::reply::json(&matches)) } @@ -190,6 +194,7 @@ pub async fn search_file( .into_iter() .map(|row| File { id: row.get("id"), + id_str: row.get::<&str, i32>("id").to_string(), url: row.get("url"), filename: row.get("filename"), artists: row diff --git a/src/models.rs b/src/models.rs index ed90012..d3c8d59 100644 --- a/src/models.rs +++ b/src/models.rs @@ -35,24 +35,43 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option { } } +pub struct ImageQueryResults { + pub furaffinity: Vec, + pub e621: Vec, + pub twitter: Vec, +} + +impl ImageQueryResults { + #[inline] + pub fn len(&self) -> usize { + self.furaffinity.len() + self.e621.len() + self.twitter.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + pub async fn image_query( db: DB<'_>, hashes: Vec, distance: i64, -) -> Result<(Vec, Vec), tokio_postgres::Error> { +) -> Result { let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::with_capacity(hashes.len() + 1); params.insert(0, &distance); let mut fa_where_clause = Vec::with_capacity(hashes.len()); - let mut e621_where_clause = Vec::with_capacity(hashes.len()); + let mut hash_where_clause = Vec::with_capacity(hashes.len()); for (idx, hash) in hashes.iter().enumerate() { params.push(hash); fa_where_clause.push(format!(" hash_int <@ (${}, $1)", idx + 2)); - e621_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2)); + hash_where_clause.push(format!(" hash <@ (${}, $1)", idx + 2)); } + let hash_where_clause = hash_where_clause.join(" OR "); let fa_query = format!( "SELECT @@ -93,12 +112,30 @@ pub async fn image_query( ) artists WHERE {}", - e621_where_clause.join(" OR ") + &hash_where_clause ); - let fa = db.query::(&*fa_query, ¶ms); - let e621 = db.query::(&*e621_query, ¶ms); + let twitter_query = format!( + "SELECT + twitter_view.id, + twitter_view.artists, + twitter_view.url, + twitter_view.hash + FROM + twitter_view + WHERE + {}", + &hash_where_clause + ); - let results = futures::future::join(fa, e621).await; - Ok((results.0?, results.1?)) + let furaffinity = db.query::(&*fa_query, ¶ms); + let e621 = db.query::(&*e621_query, ¶ms); + let twitter = db.query::(&*twitter_query, ¶ms); + + let results = futures::future::join3(furaffinity, e621, twitter).await; + Ok(ImageQueryResults { + furaffinity: results.0?, + e621: results.1?, + twitter: results.2?, + }) } diff --git a/src/types.rs b/src/types.rs index d339d01..330917d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -25,7 +25,8 @@ pub enum RateLimit { /// A general type for every file. #[derive(Debug, Default, Serialize)] pub struct File { - pub id: i32, + pub id: i64, + pub id_str: String, pub url: String, pub filename: String, pub artists: Option>, @@ -46,6 +47,7 @@ pub enum SiteInfo { FurAffinity(FurAffinityFile), #[serde(rename = "e621")] E621(E621File), + Twitter, } /// Information about a file hosted on FurAffinity. diff --git a/src/utils.rs b/src/utils.rs index 85eb3c1..2495c66 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -76,7 +76,8 @@ pub fn extract_fa_rows<'a>( let dbbytes: Vec = row.get("hash"); File { - id: row.get("id"), + id: row.get::<&str, i32>("id") as i64, + id_str: row.get::<&str, i32>("id").to_string(), url: row.get("url"), filename: row.get("filename"), hash: row.get("hash_int"), @@ -100,7 +101,8 @@ pub fn extract_e621_rows<'a>( let dbbytes = dbhash.to_be_bytes(); File { - id: row.get("id"), + id: row.get::<&str, i32>("id") as i64, + id_str: row.get::<&str, i32>("id").to_string(), url: row.get("url"), hash: Some(dbhash), distance: hash @@ -115,3 +117,37 @@ pub fn extract_e621_rows<'a>( } }) } + +pub fn extract_twitter_rows<'a>( + rows: Vec, + hash: Option<&'a [u8]>, +) -> impl IntoIterator + 'a { + rows.into_iter().map(move |row| { + let dbhash: i64 = row.get("hash"); + let dbbytes = dbhash.to_be_bytes(); + + let url: String = row.get("url"); + + let filename = url + .split('/') + .last() + .unwrap() + .split(':') + .next() + .unwrap() + .to_string(); + + File { + id: row.get("id"), + id_str: row.get::<&str, i64>("id").to_string(), + url, + hash: Some(dbhash), + distance: hash + .map(|hash| hamming::distance_fast(&dbbytes, &hash).ok()) + .flatten(), + site_info: Some(SiteInfo::Twitter), + artists: row.get("artists"), + filename, + } + }) +}