diff --git a/Cargo.lock b/Cargo.lock
index f84d390..f98440f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1263,9 +1263,9 @@ checksum = "05842d0d43232b23ccb7060ecb0f0626922c21f30012e97b767b30afd4a5d4b9"
 
 [[package]]
 name = "hyper"
-version = "0.14.6"
+version = "0.14.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f006b8784cfb01fe7aa9c46f5f5cd4cf5c85a8c612a0653ec97642979062665"
+checksum = "1e5f105c494081baa3bf9e200b279e27ec1623895cd504c7dbef8d0b080fcf54"
 dependencies = [
  "bytes",
  "futures-channel",
@@ -2394,9 +2394,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.4.5"
+version = "1.4.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "957056ecddbeba1b26965114e191d2e8589ce74db242b6ea25fc4062427a5c19"
+checksum = "2a26af418b574bd56588335b3a3659a65725d4e636eb1016c2f9e3b38c7cc759"
 dependencies = [
  "aho-corasick",
  "memchr",
diff --git a/fuzzysearch-webhook/src/main.rs b/fuzzysearch-webhook/src/main.rs
index 353e611..183d456 100644
--- a/fuzzysearch-webhook/src/main.rs
+++ b/fuzzysearch-webhook/src/main.rs
@@ -49,10 +49,9 @@ fn main() {
             let data = job
                 .args()
-                .into_iter()
+                .iter()
                 .next()
                 .ok_or(WebhookError::MissingData)?
-                .to_owned()
                 .to_owned();
 
             let value: fuzzysearch_common::types::WebHookData =
                 serde_json::value::from_value(data)?;
diff --git a/fuzzysearch/src/filters.rs b/fuzzysearch/src/filters.rs
index f03a77f..f2e7dd4 100644
--- a/fuzzysearch/src/filters.rs
+++ b/fuzzysearch/src/filters.rs
@@ -10,7 +10,6 @@ pub fn search(
 ) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     search_image(db.clone(), tree.clone())
         .or(search_hashes(db.clone(), tree.clone()))
-        .or(stream_search_image(db.clone(), tree.clone()))
         .or(search_file(db.clone()))
         .or(search_video(db.clone()))
         .or(check_handle(db.clone()))
@@ -89,26 +88,6 @@ pub fn search_hashes(
         })
 }
 
-pub fn stream_search_image(
-    db: Pool,
-    tree: Tree,
-) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
-    warp::path("stream")
-        .and(warp::header::headers_cloned())
-        .and(warp::post())
-        .and(warp::multipart::form().max_length(1024 * 1024 * 10))
-        .and(with_pool(db))
-        .and(with_tree(tree))
-        .and(with_api_key())
-        .and_then(|headers, form, pool, tree, api_key| {
-            use tracing_opentelemetry::OpenTelemetrySpanExt;
-
-            let span = tracing::info_span!("stream_search_image");
-            span.set_parent(with_telem(headers));
-            span.in_scope(|| handlers::stream_image(form, pool, tree, api_key).in_current_span())
-        })
-}
-
 pub fn search_video(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     warp::path("video")
         .and(warp::header::headers_cloned())
diff --git a/fuzzysearch/src/handlers.rs b/fuzzysearch/src/handlers.rs
index b5f9d46..a67943a 100644
--- a/fuzzysearch/src/handlers.rs
+++ b/fuzzysearch/src/handlers.rs
@@ -5,7 +5,7 @@ use tracing::{span, warn};
 use tracing_futures::Instrument;
 use warp::{Rejection, Reply};
 
-use crate::models::{image_query, image_query_sync};
+use crate::models::image_query;
 use crate::types::*;
 use crate::{early_return, rate_limit, Pool, Tree};
 use fuzzysearch_common::types::{SearchResult, SiteInfo};
@@ -112,7 +112,7 @@ async fn get_field_bytes(form: warp::multipart::FormData, field: &str) -> bytes:
 }
 
 #[tracing::instrument(skip(form))]
-async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHash<[u8; 8]>) {
+async fn hash_input(form: warp::multipart::FormData) -> i64 {
     let bytes = get_field_bytes(form, "image").await;
 
     let len = bytes.len();
@@ -131,7 +131,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
     let mut buf: [u8; 8] = [0; 8];
     buf.copy_from_slice(&hash.as_bytes());
 
-    (i64::from_be_bytes(buf), hash)
+    i64::from_be_bytes(buf)
 }
 
 #[tracing::instrument(skip(form))]
@@ -168,39 +168,21 @@ pub async fn search_image(
     let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
     let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
 
-    let (num, hash) = hash_input(form).await;
+    let num = hash_input(form).await;
 
     let mut items = {
         if opts.search_type == Some(ImageSearchType::Force) {
-            image_query(
-                db.clone(),
-                tree.clone(),
-                vec![num],
-                10,
-                Some(hash.as_bytes().to_vec()),
-            )
-            .await
-            .unwrap()
-        } else {
-            let results = image_query(
-                db.clone(),
-                tree.clone(),
-                vec![num],
-                0,
-                Some(hash.as_bytes().to_vec()),
-            )
-            .await
-            .unwrap();
-            if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
-                image_query(
-                    db.clone(),
-                    tree.clone(),
-                    vec![num],
-                    10,
-                    Some(hash.as_bytes().to_vec()),
-                )
+            image_query(db.clone(), tree.clone(), vec![num], 10)
                 .await
                 .unwrap()
+        } else {
+            let results = image_query(db.clone(), tree.clone(), vec![num], 0)
+                .await
+                .unwrap();
+            if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
+                image_query(db.clone(), tree.clone(), vec![num], 10)
+                    .await
+                    .unwrap()
             } else {
                 results
             }
@@ -235,43 +217,6 @@ pub async fn search_image(
     Ok(Box::new(resp))
 }
 
-pub async fn stream_image(
-    form: warp::multipart::FormData,
-    pool: Pool,
-    tree: Tree,
-    api_key: String,
-) -> Result<Box<dyn Reply>, Rejection> {
-    rate_limit!(&api_key, &pool, image_limit, "image", 2);
-    rate_limit!(&api_key, &pool, hash_limit, "hash");
-
-    let (num, hash) = hash_input(form).await;
-
-    let mut query = image_query_sync(
-        pool.clone(),
-        tree,
-        vec![num],
-        10,
-        Some(hash.as_bytes().to_vec()),
-    );
-
-    let event_stream = async_stream::stream! {
-        while let Some(result) = query.recv().await {
-            yield sse_matches(result);
-        }
-    };
-
-    Ok(Box::new(warp::sse::reply(event_stream)))
-}
-
-#[allow(clippy::unnecessary_wraps)]
-fn sse_matches(
-    matches: Result<Vec<SearchResult>, sqlx::Error>,
-) -> Result<warp::sse::Event, core::convert::Infallible> {
-    let items = matches.unwrap();
-
-    Ok(warp::sse::Event::default().json_data(items).unwrap())
-}
-
 pub async fn search_hashes(
     opts: HashSearchOpts,
     db: Pool,
@@ -293,18 +238,8 @@ pub async fn search_hashes(
     let image_remaining = rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
 
-    let mut results = image_query_sync(
-        pool,
-        tree,
-        hashes.clone(),
-        opts.distance.unwrap_or(10),
-        None,
-    );
-    let mut matches = Vec::new();
-
-    while let Some(r) = results.recv().await {
-        matches.extend(early_return!(r));
-    }
+    let results =
+        early_return!(image_query(pool, tree, hashes.clone(), opts.distance.unwrap_or(10)).await);
 
     let resp = warp::http::Response::builder()
         .header("x-rate-limit-total-image", image_remaining.1.to_string())
         .header(
@@ -312,7 +247,7 @@ pub async fn search_hashes(
             image_remaining.0.to_string(),
         )
         .header("content-type", "application/json")
-        .body(serde_json::to_string(&matches).unwrap())
+        .body(serde_json::to_string(&results).unwrap())
         .unwrap();
 
     Ok(Box::new(resp))
@@ -427,10 +362,10 @@ pub async fn search_file(
                 .map(|artist| vec![artist]),
             distance: None,
             hash: None,
+            searched_hash: None,
             site_info: Some(SiteInfo::FurAffinity {
                 file_id: row.get("file_id"),
            }),
-            searched_hash: None,
             rating: row
                 .get::<Option<&str>, _>("rating")
                 .and_then(|rating| rating.parse().ok()),
@@ -452,8 +387,8 @@ pub async fn search_file(
 
 pub async fn search_video(
     form: warp::multipart::FormData,
-    db: Pool,
-    api_key: String,
+    _db: Pool,
+    _api_key: String,
 ) -> Result<impl Reply, Rejection> {
     let hashes = hash_video(form).await;
 
@@ -535,7 +470,7 @@ pub async fn search_image_by_url(
     let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
     let num = i64::from_be_bytes(hash);
 
-    let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec()))
+    let results = image_query(db.clone(), tree.clone(), vec![num], 3)
         .await
         .unwrap();
diff --git a/fuzzysearch/src/main.rs b/fuzzysearch/src/main.rs
index a4689c4..c762c42 100644
--- a/fuzzysearch/src/main.rs
+++ b/fuzzysearch/src/main.rs
@@ -171,26 +171,31 @@ async fn create_tree(conn: &Pool) -> bk_tree::BKTree<Node, Hamming> {
     let mut tree = bk_tree::BKTree::new(Hamming);
 
     let mut rows = sqlx::query!(
-        "SELECT id, hash_int hash FROM submission WHERE hash_int IS NOT NULL
-        UNION ALL
-        SELECT id, hash FROM e621 WHERE hash IS NOT NULL
-        UNION ALL
-        SELECT tweet_id, hash FROM tweet_media WHERE hash IS NOT NULL
-        UNION ALL
-        SELECT id, hash FROM weasyl WHERE hash IS NOT NULL"
+        "SELECT hash_int hash FROM submission WHERE hash_int IS NOT NULL
+        UNION
+        SELECT hash FROM e621 WHERE hash IS NOT NULL
+        UNION
+        SELECT hash FROM tweet_media WHERE hash IS NOT NULL
+        UNION
+        SELECT hash FROM weasyl WHERE hash IS NOT NULL"
     )
     .fetch(conn);
 
+    let mut count = 0;
+
     while let Some(row) = rows.try_next().await.expect("Unable to get row") {
         if let Some(hash) = row.hash {
-            if tree.find_exact(&Node::new(hash)).is_some() {
-                continue;
-            }
-            tree.add(Node::new(hash));
+            tree.add(Node::new(hash));
+            count += 1;
+
+            if count % 250_000 == 0 {
+                tracing::debug!(count, "Made progress in loading tree rows");
+            }
         }
     }
 
+    tracing::info!(count, "Completed loading rows for tree");
+
     tree
 }
diff --git a/fuzzysearch/src/models.rs b/fuzzysearch/src/models.rs
index 2f9b0de..106690b 100644
--- a/fuzzysearch/src/models.rs
+++ b/fuzzysearch/src/models.rs
@@ -1,18 +1,14 @@
-use std::collections::HashSet;
-
 use lazy_static::lazy_static;
 use prometheus::{register_histogram, Histogram};
-use tracing_futures::Instrument;
 
 use crate::types::*;
 use crate::{Pool, Tree};
-use futures::TryStreamExt;
 use fuzzysearch_common::types::{SearchResult, SiteInfo};
 
 lazy_static! {
-    static ref IMAGE_LOOKUP_DURATION: Histogram = register_histogram!(
-        "fuzzysearch_api_image_lookup_seconds",
-        "Duration to perform an image lookup"
+    static ref IMAGE_TREE_DURATION: Histogram = register_histogram!(
+        "fuzzysearch_api_image_tree_seconds",
+        "Duration to search for hashes in tree"
     )
     .unwrap();
     static ref IMAGE_QUERY_DURATION: Histogram = register_histogram!(
@@ -48,153 +44,146 @@ pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option<ApiKey> {
         .flatten()
 }
 
+#[derive(serde::Serialize)]
+struct HashSearch {
+    searched_hash: i64,
+    found_hash: i64,
+    distance: u64,
+}
+
 #[tracing::instrument(skip(pool, tree))]
 pub async fn image_query(
     pool: Pool,
     tree: Tree,
     hashes: Vec<i64>,
     distance: i64,
-    hash: Option<Vec<u8>>,
 ) -> Result<Vec<SearchResult>, sqlx::Error> {
-    let mut results = image_query_sync(pool, tree, hashes, distance, hash);
-    let mut matches = Vec::new();
+    let timer = IMAGE_TREE_DURATION.start_timer();
+    let lock = tree.read().await;
+    let found_hashes: Vec<HashSearch> = hashes
+        .iter()
+        .flat_map(|hash| {
+            lock.find(&crate::Node::new(*hash), distance as u64)
+                .map(|(dist, node)| HashSearch {
+                    searched_hash: *hash,
+                    found_hash: node.num(),
+                    distance: dist,
+                })
+                .collect::<Vec<_>>()
+        })
+        .collect();
+    timer.stop_and_record();
 
-    while let Some(r) = results.recv().await {
-        matches.extend(r?);
-    }
+    let timer = IMAGE_QUERY_DURATION.start_timer();
+    let matches = sqlx::query!(
+        "WITH hashes AS (
+            SELECT * FROM jsonb_to_recordset($1::jsonb)
+                AS hashes(searched_hash bigint, found_hash bigint, distance bigint)
+        )
+        SELECT
+            'FurAffinity' site,
+            submission.id,
+            submission.hash_int hash,
+            submission.url,
+            submission.filename,
+            ARRAY(SELECT artist.name) artists,
+            submission.file_id,
+            null sources,
+            submission.rating,
+            hashes.searched_hash,
+            hashes.distance
+        FROM hashes
+        JOIN submission ON hashes.found_hash = submission.hash_int
+        JOIN artist ON submission.artist_id = artist.id
+        WHERE hash_int IN (SELECT hashes.found_hash)
+        UNION ALL
+        SELECT
+            'e621' site,
+            e621.id,
+            e621.hash,
+            e621.data->'file'->>'url' url,
+            (e621.data->'file'->>'md5') || '.' || (e621.data->'file'->>'ext') filename,
+            ARRAY(SELECT jsonb_array_elements_text(e621.data->'tags'->'artist')) artists,
+            null file_id,
+            ARRAY(SELECT jsonb_array_elements_text(e621.data->'sources')) sources,
+            e621.data->>'rating' rating,
+            hashes.searched_hash,
+            hashes.distance
+        FROM hashes
+        JOIN e621 ON hashes.found_hash = e621.hash
+        WHERE e621.hash IN (SELECT hashes.found_hash)
+        UNION ALL
+        SELECT
+            'Weasyl' site,
+            weasyl.id,
+            weasyl.hash,
+            weasyl.data->>'link' url,
+            null filename,
+            ARRAY(SELECT weasyl.data->>'owner_login') artists,
+            null file_id,
+            null sources,
+            weasyl.data->>'rating' rating,
+            hashes.searched_hash,
+            hashes.distance
+        FROM hashes
+        JOIN weasyl ON hashes.found_hash = weasyl.hash
+        WHERE weasyl.hash IN (SELECT hashes.found_hash)
+        UNION ALL
+        SELECT
+            'Twitter' site,
+            tweet.id,
+            tweet_media.hash,
+            tweet_media.url,
+            null filename,
+            ARRAY(SELECT tweet.data->'user'->>'screen_name') artists,
+            null file_id,
+            null sources,
+            CASE
+                WHEN (tweet.data->'possibly_sensitive')::boolean IS true THEN 'adult'
+                WHEN (tweet.data->'possibly_sensitive')::boolean IS false THEN 'general'
+            END rating,
+            hashes.searched_hash,
+            hashes.distance
+        FROM hashes
+        JOIN tweet_media ON hashes.found_hash = tweet_media.hash
+        JOIN tweet ON tweet_media.tweet_id = tweet.id
+        WHERE tweet_media.hash IN (SELECT hashes.found_hash)",
+        serde_json::to_value(&found_hashes).unwrap()
+    )
+    .map(|row| {
+        use std::convert::TryFrom;
+
+        let site_info = match row.site.as_deref() {
+            Some("FurAffinity") => SiteInfo::FurAffinity {
+                file_id: row.file_id.unwrap_or(-1),
+            },
+            Some("e621") => SiteInfo::E621 {
+                sources: row.sources,
+            },
+            Some("Twitter") => SiteInfo::Twitter,
+            Some("Weasyl") => SiteInfo::Weasyl,
+            _ => panic!("Got unknown site"),
+        };
+
+        SearchResult {
+            site_id: row.id.unwrap_or_default(),
+            site_info: Some(site_info),
+            rating: row.rating.and_then(|rating| rating.parse().ok()),
+            site_id_str: row.id.unwrap_or_default().to_string(),
+            url: row.url.unwrap_or_default(),
+            hash: row.hash,
+            distance: row
+                .distance
+                .map(|distance| u64::try_from(distance).ok())
+                .flatten(),
+            artists: row.artists,
+            filename: row.filename.unwrap_or_default(),
+            searched_hash: row.searched_hash,
+        }
+    })
+    .fetch_all(&pool)
+    .await?;
+    timer.stop_and_record();
 
     Ok(matches)
 }
-
-#[tracing::instrument(skip(pool, tree))]
-pub fn image_query_sync(
-    pool: Pool,
-    tree: Tree,
-    hashes: Vec<i64>,
-    distance: i64,
-    hash: Option<Vec<u8>>,
-) -> tokio::sync::mpsc::Receiver<Result<Vec<SearchResult>, sqlx::Error>> {
-    let (tx, rx) = tokio::sync::mpsc::channel(50);
-
-    tokio::spawn(
-        async move {
-            let db = pool;
-
-            for query_hash in hashes {
-                tracing::trace!(query_hash, "Evaluating hash");
-
-                let mut seen: HashSet<[u8; 8]> = HashSet::new();
-
-                let _timer = IMAGE_LOOKUP_DURATION.start_timer();
-
-                let node = crate::Node::query(query_hash.to_be_bytes());
-                let lock = tree.read().await;
-                let items = lock.find(&node, distance as u64);
-
-                for (dist, item) in items {
-                    if seen.contains(&item.0) {
-                        tracing::trace!("Already searched for hash");
-                        continue;
-                    }
-                    seen.insert(item.0);
-
-                    let _timer = IMAGE_QUERY_DURATION.start_timer();
-
-                    tracing::debug!(num = item.num(), "Searching database for hash in tree");
-
-                    let mut row = sqlx::query!(
-                        "SELECT
-                            'FurAffinity' site,
-                            submission.id,
-                            submission.hash_int hash,
-                            submission.url,
-                            submission.filename,
-                            ARRAY(SELECT artist.name) artists,
-                            submission.file_id,
-                            null sources,
-                            submission.rating
-                        FROM submission
-                        JOIN artist ON submission.artist_id = artist.id
-                        WHERE hash_int <@ ($1, 0)
-                        UNION ALL
-                        SELECT
-                            'e621' site,
-                            e621.id,
-                            e621.hash,
-                            e621.data->'file'->>'url' url,
-                            (e621.data->'file'->>'md5') || '.' || (e621.data->'file'->>'ext') filename,
-                            ARRAY(SELECT jsonb_array_elements_text(e621.data->'tags'->'artist')) artists,
-                            null file_id,
-                            ARRAY(SELECT jsonb_array_elements_text(e621.data->'sources')) sources,
-                            e621.data->>'rating' rating
-                        FROM e621
-                        WHERE hash <@ ($1, 0)
-                        UNION ALL
-                        SELECT
-                            'Weasyl' site,
-                            weasyl.id,
-                            weasyl.hash,
-                            weasyl.data->>'link' url,
-                            null filename,
-                            ARRAY(SELECT weasyl.data->>'owner_login') artists,
-                            null file_id,
-                            null sources,
-                            weasyl.data->>'rating' rating
-                        FROM weasyl
-                        WHERE hash <@ ($1, 0)
-                        UNION ALL
-                        SELECT
-                            'Twitter' site,
-                            tweet.id,
-                            tweet_media.hash,
-                            tweet_media.url,
-                            null filename,
-                            ARRAY(SELECT tweet.data->'user'->>'screen_name') artists,
-                            null file_id,
-                            null sources,
-                            CASE
-                                WHEN (tweet.data->'possibly_sensitive')::boolean IS true THEN 'adult'
-                                WHEN (tweet.data->'possibly_sensitive')::boolean IS false THEN 'general'
-                            END rating
-                        FROM tweet_media
-                        JOIN tweet ON tweet_media.tweet_id = tweet.id
-                        WHERE hash <@ ($1, 0)",
-                        &item.num()
-                    )
-                    .map(|row| {
-                        let site_info = match row.site.as_deref() {
-                            Some("FurAffinity") => SiteInfo::FurAffinity { file_id: row.file_id.unwrap_or(-1) },
-                            Some("e621") => SiteInfo::E621 { sources: row.sources },
-                            Some("Twitter") => SiteInfo::Twitter,
-                            Some("Weasyl") => SiteInfo::Weasyl,
-                            _ => panic!("Got unknown site"),
-                        };
-
-                        let file = SearchResult {
-                            site_id: row.id.unwrap_or_default(),
-                            site_info: Some(site_info),
-                            rating: row.rating.and_then(|rating| rating.parse().ok()),
-                            site_id_str: row.id.unwrap_or_default().to_string(),
-                            url: row.url.unwrap_or_default(),
-                            hash: row.hash,
-                            distance: Some(dist),
-                            artists: row.artists,
-                            filename: row.filename.unwrap_or_default(),
-                            searched_hash: Some(query_hash),
-                        };
-
-                        vec![file]
-                    })
-                    .fetch(&db);
-
-                    while let Some(row) = row.try_next().await.ok().flatten() {
-                        tx.send(Ok(row)).await.unwrap();
-                    }
-                }
-            }
-        }
-        .in_current_span(),
-    );
-
-    rx
-}
diff --git a/migrations/20210422224815_change_hash_index.down.sql b/migrations/20210422224815_change_hash_index.down.sql
new file mode 100644
index 0000000..d271913
--- /dev/null
+++ b/migrations/20210422224815_change_hash_index.down.sql
@@ -0,0 +1,9 @@
+DROP INDEX submission_hash_int_idx;
+DROP INDEX e621_hash_idx;
+DROP INDEX tweet_media_hash_idx;
+DROP INDEX weasyl_hash_idx;
+
+CREATE INDEX bk_furaffinity_hash ON submission USING spgist (hash_int bktree_ops);
+CREATE INDEX bk_e621_hash ON e621 USING spgist (hash bktree_ops);
+CREATE INDEX bk_twitter_hash ON tweet_media USING spgist (hash bktree_ops);
+CREATE INDEX bk_weasyl_hash ON weasyl USING spgist (hash bktree_ops);
diff --git a/migrations/20210422224815_change_hash_index.up.sql b/migrations/20210422224815_change_hash_index.up.sql
new file mode 100644
index 0000000..1bb5559
--- /dev/null
+++ b/migrations/20210422224815_change_hash_index.up.sql
@@ -0,0 +1,9 @@
+DROP INDEX bk_furaffinity_hash;
+DROP INDEX bk_e621_hash;
+DROP INDEX bk_twitter_hash;
+DROP INDEX bk_weasyl_hash;
+
+CREATE INDEX submission_hash_int_idx ON submission (hash_int);
+CREATE INDEX e621_hash_idx ON e621 (hash);
+CREATE INDEX tweet_media_hash_idx ON tweet_media (hash);
+CREATE INDEX weasyl_hash_idx ON weasyl (hash);