From 904d3290e10ce0242ed773a611056147680c1d97 Mon Sep 17 00:00:00 2001
From: Syfaro
Date: Sat, 15 Feb 2020 23:50:09 -0600
Subject: [PATCH] Initial attempt at an in-memory tree.

---
 .drone.yml      |  22 ++++++--
 Cargo.lock      |   7 +++
 Cargo.toml      |   2 +
 src/filters.rs  |  37 +++++++++----
 src/handlers.rs |  66 ++++++++++++--------
 src/main.rs     |  96 +++++++++++++++++++++++++++++-
 src/models.rs   | 142 +++++++++++++++++++++++--------------------------
 7 files changed, 260 insertions(+), 112 deletions(-)

diff --git a/.drone.yml b/.drone.yml
index 805e343..ba1f691 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -8,7 +8,7 @@ platform:
   arch: amd64
 
 steps:
-- name: docker
+- name: build-latest
   image: plugins/docker
   settings:
     auto_tag: true
@@ -18,9 +18,23 @@ steps:
     repo: registry.huefox.com/fuzzysearch
     username:
       from_secret: docker_username
+  when:
+    branch:
+      - master
 
-trigger:
-  branch:
-    - master
+- name: build-branch
+  image: plugins/docker
+  settings:
+    password:
+      from_secret: docker_password
+    registry: registry.huefox.com
+    repo: registry.huefox.com/fuzzysearch
+    tags: ${DRONE_BRANCH}
+    username:
+      from_secret: docker_username
+  when:
+    branch:
+      exclude:
+        - master
 
 ...
diff --git a/Cargo.lock b/Cargo.lock
index 1cc2efd..0b0acc3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -131,6 +131,12 @@ version = "1.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
 
+[[package]]
+name = "bk-tree"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5488039ea2c6de8668351415e39a0218a8955bffadcff0cf01d1293a20854584"
+
 [[package]]
 name = "block-buffer"
 version = "0.7.3"
@@ -501,6 +507,7 @@ version = "0.1.0"
 dependencies = [
  "bb8",
  "bb8-postgres",
+ "bk-tree",
  "bytes 0.5.4",
  "chrono",
  "futures",
diff --git a/Cargo.toml b/Cargo.toml
index c158649..13c6293 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,8 @@ img_hash = "3"
 image = "0.22"
 hamming = "0.1"
 
+bk-tree = "0.3"
+
 [profile.release]
 lto = true
 codegen-units = 1
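Note (reviewer addition, not part of the patch): bk-tree implements a Burkhard-Keller tree, which stores keys in a metric space so a lookup with a distance tolerance can prune whole subtrees via the triangle inequality instead of comparing the query against every stored hash. A minimal sketch of the bk-tree 0.3 API as this patch uses it; BitHamming and the bare u64 keys are illustrative stand-ins for the Node and Hamming types the patch defines in src/main.rs:

    use bk_tree::{BKTree, Metric};

    // Hamming distance between two 64-bit hashes: count differing bits.
    struct BitHamming;

    impl Metric<u64> for BitHamming {
        fn distance(&self, a: &u64, b: &u64) -> u64 {
            (a ^ b).count_ones() as u64
        }
    }

    fn main() {
        let mut tree: BKTree<u64, BitHamming> = BKTree::new(BitHamming);
        tree.add(0b1010_1010);
        tree.add(0b1010_1011);
        tree.add(0b0101_0101);

        // find() yields (distance, &key) pairs within the tolerance.
        for (dist, hash) in tree.find(&0b1010_1000, 2) {
            println!("{:#010b} at distance {}", hash, dist);
        }
    }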
diff --git a/src/filters.rs b/src/filters.rs
index 9d4d8f0..1cd4e6e 100644
--- a/src/filters.rs
+++ b/src/filters.rs
@@ -1,12 +1,15 @@
 use crate::types::*;
-use crate::{handlers, Pool};
+use crate::{handlers, Pool, Tree};
 use std::convert::Infallible;
 use warp::{Filter, Rejection, Reply};
 
-pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
-    search_image(db.clone())
-        .or(search_hashes(db.clone()))
-        .or(stream_search_image(db.clone()))
+pub fn search(
+    db: Pool,
+    tree: Tree,
+) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
+    search_image(db.clone(), tree.clone())
+        .or(search_hashes(db.clone(), tree.clone()))
+        .or(stream_search_image(db.clone(), tree))
         .or(search_file(db))
 }
 
@@ -20,35 +23,45 @@ pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
         .and_then(handlers::search_file)
 }
 
-pub fn search_image(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
+pub fn search_image(
+    db: Pool,
+    tree: Tree,
+) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     warp::path("image")
         .and(with_telem())
         .and(warp::post())
         .and(warp::multipart::form().max_length(1024 * 1024 * 10))
         .and(warp::query::<ImageSearchOpts>())
         .and(with_pool(db))
+        .and(with_tree(tree))
         .and(with_api_key())
         .and_then(handlers::search_image)
 }
 
-pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
+pub fn search_hashes(
+    db: Pool,
+    tree: Tree,
+) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     warp::path("hashes")
         .and(with_telem())
         .and(warp::get())
         .and(warp::query::<HashSearchOpts>())
         .and(with_pool(db))
+        .and(with_tree(tree))
         .and(with_api_key())
         .and_then(handlers::search_hashes)
 }
 
 pub fn stream_search_image(
     db: Pool,
+    tree: Tree,
 ) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
     warp::path("stream")
         .and(with_telem())
         .and(warp::post())
         .and(warp::multipart::form().max_length(1024 * 1024 * 10))
         .and(with_pool(db))
+        .and(with_tree(tree))
         .and(with_api_key())
         .and_then(handlers::stream_image)
 }
@@ -61,6 +74,10 @@ fn with_pool(db: Pool) -> impl Filter<Extract = (Pool,), Error = Infallible> + Clone {
     warp::any().map(move || db.clone())
 }
 
+fn with_tree(tree: Tree) -> impl Filter<Extract = (Tree,), Error = Infallible> + Clone {
+    warp::any().map(move || tree.clone())
+}
+
 fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Clone {
     warp::any()
         .and(warp::header::optional("traceparent"))
@@ -75,7 +92,7 @@ fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Clone {
 
             tracing::trace!("got context from request: {:?}", context);
 
-            let span = if context.is_valid() {
+            if context.is_valid() {
                 let tracer = opentelemetry::global::trace_provider().get_tracer("api");
                 let span = tracer.start("context", Some(context));
                 tracer.mark_span_as_active(&span);
@@ -83,8 +100,6 @@ fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Clone {
                 Some(span)
             } else {
                 None
-            };
-
-            span
+            }
         })
 }
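Note (reviewer addition, not part of the patch): with_tree follows the same warp idiom as the existing with_pool, cloning shared state into every request so the handler receives it as a plain argument. A generic sketch of the pattern; with_state is a hypothetical name:

    use std::convert::Infallible;
    use warp::Filter;

    // Clone a piece of shared state into every request so the final
    // handler can accept it as an ordinary argument via .and(...).
    fn with_state<T: Clone + Send>(
        state: T,
    ) -> impl Filter<Extract = (T,), Error = Infallible> + Clone {
        warp::any().map(move || state.clone())
    }

Since Tree is an Arc, the per-request clone is a reference-count bump, not a copy of the BK-tree.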
diff --git a/src/handlers.rs b/src/handlers.rs
index 789a79c..bb2a17c 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -1,6 +1,6 @@
 use crate::models::{image_query, image_query_sync};
 use crate::types::*;
-use crate::{rate_limit, Pool};
+use crate::{rate_limit, Pool, Tree};
 use tracing::{span, warn};
 use tracing_futures::Instrument;
 use warp::{reject, Rejection, Reply};
@@ -76,12 +76,13 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHash) {
     (i64::from_be_bytes(buf), hash)
 }
 
-#[tracing::instrument(skip(_telem, form, pool, api_key))]
+#[tracing::instrument(skip(_telem, form, pool, tree, api_key))]
 pub async fn search_image(
     _telem: crate::Span,
     form: warp::multipart::FormData,
     opts: ImageSearchOpts,
     pool: Pool,
+    tree: Tree,
     api_key: String,
 ) -> Result<impl Reply, Rejection> {
     let db = pool.get().await.map_err(map_bb8_err)?;
@@ -92,17 +93,35 @@ pub async fn search_image(
 
     let mut items = {
         if opts.search_type == Some(ImageSearchType::Force) {
-            image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
+            image_query(
+                pool.clone(),
+                tree.clone(),
+                vec![num],
+                10,
+                Some(hash.as_bytes().to_vec()),
+            )
+            .await
+            .unwrap()
+        } else {
+            let results = image_query(
+                pool.clone(),
+                tree.clone(),
+                vec![num],
+                0,
+                Some(hash.as_bytes().to_vec()),
+            )
+            .await
+            .unwrap();
+            if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
+                image_query(
+                    pool.clone(),
+                    tree.clone(),
+                    vec![num],
+                    10,
+                    Some(hash.as_bytes().to_vec()),
+                )
                 .await
                 .unwrap()
-        } else {
-            let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
-                .await
-                .unwrap();
-            if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
-                image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
-                    .await
-                    .unwrap()
             } else {
                 results
             }
@@ -124,11 +143,12 @@ pub async fn search_image(
     Ok(warp::reply::json(&similarity))
 }
 
-#[tracing::instrument(skip(_telem, form, pool, api_key))]
+#[tracing::instrument(skip(_telem, form, pool, tree, api_key))]
 pub async fn stream_image(
     _telem: crate::Span,
     form: warp::multipart::FormData,
     pool: Pool,
+    tree: Tree,
     api_key: String,
 ) -> Result<impl Reply, Rejection> {
     use futures_util::StreamExt;
@@ -139,15 +159,14 @@ pub async fn stream_image(
 
     let (num, hash) = hash_input(form).await;
 
-    let exact_event_stream =
-        image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec()))
-            .map(sse_matches);
-
-    let close_event_stream =
-        image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
-            .map(sse_matches);
-
-    let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
+    let event_stream = image_query_sync(
+        pool.clone(),
+        tree,
+        vec![num],
+        10,
+        Some(hash.as_bytes().to_vec()),
+    )
+    .map(sse_matches);
 
     Ok(warp::sse::reply(event_stream))
 }
@@ -160,11 +179,12 @@ fn sse_matches(
     Ok(warp::sse::json(items))
 }
 
-#[tracing::instrument(skip(_telem, form, db, api_key))]
+#[tracing::instrument(skip(_telem, form, db, tree, api_key))]
 pub async fn search_hashes(
     _telem: crate::Span,
     opts: HashSearchOpts,
     db: Pool,
+    tree: Tree,
     api_key: String,
 ) -> Result<impl Reply, Rejection> {
     let pool = db.clone();
@@ -183,7 +203,7 @@ pub async fn search_hashes(
 
     rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
 
-    let mut results = image_query_sync(pool, hashes.clone(), 10, None);
+    let mut results = image_query_sync(pool, tree, hashes.clone(), 10, None);
     let mut matches = Vec::new();
 
     while let Some(r) = results.recv().await {
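Note (reviewer addition, not part of the patch): the rewritten search_image keeps the old two-pass policy; only the plumbing through image_query changed. A restatement of the branching with hypothetical names (the handler inlines this logic against ImageSearchType):

    enum SearchType {
        Exact,
        Close,
        Force,
    }

    // First-pass distance, plus an optional wider distance to retry
    // with when the first pass returns no matches.
    fn pick_distances(search_type: Option<SearchType>) -> (u64, Option<u64>) {
        match search_type {
            Some(SearchType::Force) => (10, None), // skip the exact pass
            Some(SearchType::Exact) => (0, None),  // never widen
            Some(SearchType::Close) | None => (0, Some(10)),
        }
    }

stream_image, by contrast, drops the merged exact-plus-close pair of streams: a single distance-10 BK-tree lookup already returns the distance-0 matches along with the close ones.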
diff --git a/src/main.rs b/src/main.rs
index 1b62a72..fb87105 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,8 @@
 #![recursion_limit = "256"]
 
 use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::RwLock;
 
 mod filters;
 mod handlers;
@@ -60,6 +62,28 @@ fn configure_tracing() {
         .expect("Unable to set default tracing subscriber");
 }
 
+#[derive(Debug)]
+pub struct Node {
+    id: i32,
+    hash: [u8; 8],
+}
+
+impl Node {
+    pub fn query(hash: [u8; 8]) -> Self {
+        Self { id: -1, hash }
+    }
+}
+
+type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
+
+pub struct Hamming;
+
+impl bk_tree::Metric<Node> for Hamming {
+    fn distance(&self, a: &Node, b: &Node) -> u64 {
+        hamming::distance_fast(&a.hash, &b.hash).unwrap()
+    }
+}
+
 #[tokio::main]
 async fn main() {
     pretty_env_logger::init();
@@ -78,6 +102,76 @@ async fn main() {
         .await
         .expect("Unable to build Postgres pool");
 
+    let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
+
+    let mut max_id = 0;
+
+    let conn = db_pool.get().await.unwrap();
+    let mut lock = tree.write().await;
+    conn.query("SELECT id, hash FROM hashes", &[])
+        .await
+        .unwrap()
+        .into_iter()
+        .for_each(|row| {
+            let id: i32 = row.get(0);
+            let hash: i64 = row.get(1);
+            let bytes = hash.to_be_bytes();
+
+            if id > max_id {
+                max_id = id;
+            }
+
+            lock.add(Node { id, hash: bytes });
+        });
+    drop(lock);
+    drop(conn);
+
+    let tree_clone = tree.clone();
+    let pool_clone = db_pool.clone();
+    tokio::spawn(async move {
+        use futures_util::StreamExt;
+
+        let max_id = std::sync::atomic::AtomicI32::new(max_id);
+        let tree = tree_clone;
+        let pool = pool_clone;
+
+        let order = std::sync::atomic::Ordering::SeqCst;
+
+        let interval = tokio::time::interval(std::time::Duration::from_secs(30));
+
+        interval
+            .for_each(|_| async {
+                tracing::debug!("Refreshing hashes");
+
+                let conn = pool.get().await.unwrap();
+                let mut lock = tree.write().await;
+                let id = max_id.load(order);
+
+                let mut count = 0;
+
+                conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
+                    .await
+                    .unwrap()
+                    .into_iter()
+                    .for_each(|row| {
+                        let id: i32 = row.get(0);
+                        let hash: i64 = row.get(1);
+                        let bytes = hash.to_be_bytes();
+
+                        if id > max_id.load(order) {
+                            max_id.store(id, order);
+                        }
+
+                        lock.add(Node { id, hash: bytes });
+
+                        count += 1;
+                    });
+
+                tracing::trace!("Added {} new hashes", count);
+            })
+            .await;
+    });
+
     let log = warp::log("fuzzysearch");
     let cors = warp::cors()
         .allow_any_origin()
@@ -86,7 +180,7 @@ async fn main() {
     let options = warp::options().map(|| "✓");
 
-    let api = options.or(filters::search(db_pool));
+    let api = options.or(filters::search(db_pool, tree));
 
     let routes = api
         .or(warp::path::end()
             .map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
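Note (reviewer addition, not part of the patch): main.rs now owns the tree. Startup loads every row of hashes under the write lock, and a spawned task polls every 30 seconds for rows above the highest id seen so far, tracked in an AtomicI32 watermark. The metric unwraps hamming::distance_fast, which errors only when the two slices differ in length; both inputs are always the 8-byte big-endian encoding of an i64, so the unwrap cannot fail. A standalone equivalent (hamming_u64 is a hypothetical helper):

    // Equivalent to (a ^ b).count_ones() as u64 for 64-bit hashes.
    fn hamming_u64(a: i64, b: i64) -> u64 {
        hamming::distance_fast(&a.to_be_bytes(), &b.to_be_bytes()).unwrap()
    }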
diff --git a/src/models.rs b/src/models.rs
index 1864908..0be59b8 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,6 +1,6 @@
 use crate::types::*;
 use crate::utils::extract_rows;
-use crate::Pool;
+use crate::{Pool, Tree};
 use tracing_futures::Instrument;
 
 pub type DB<'a> =
@@ -39,14 +39,15 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
     }
 }
 
-#[tracing::instrument(skip(pool))]
+#[tracing::instrument(skip(pool, tree))]
 pub async fn image_query(
     pool: Pool,
+    tree: Tree,
     hashes: Vec<i64>,
     distance: i64,
     hash: Option<Vec<u8>>,
 ) -> Result<Vec<File>, tokio_postgres::Error> {
-    let mut results = image_query_sync(pool, hashes, distance, hash);
+    let mut results = image_query_sync(pool, tree, hashes, distance, hash);
     let mut matches = Vec::new();
 
     while let Some(r) = results.recv().await {
@@ -56,88 +57,83 @@ pub async fn image_query(
     Ok(matches)
 }
 
-#[tracing::instrument(skip(pool))]
+#[tracing::instrument(skip(pool, tree))]
 pub fn image_query_sync(
     pool: Pool,
+    tree: Tree,
     hashes: Vec<i64>,
     distance: i64,
     hash: Option<Vec<u8>>,
 ) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
-    let (mut tx, rx) = tokio::sync::mpsc::channel(1);
+    let (mut tx, rx) = tokio::sync::mpsc::channel(50);
 
     tokio::spawn(async move {
         let db = pool.get().await.unwrap();
 
-        let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
-            Vec::with_capacity(hashes.len() + 1);
-        params.insert(0, &distance);
+        for query_hash in hashes {
+            let node = crate::Node::query(query_hash.to_be_bytes());
+            let lock = tree.read().await;
+            let items = lock.find(&node, distance as u64);
 
-        let mut hash_where_clause = Vec::with_capacity(hashes.len());
-        for (idx, hash) in hashes.iter().enumerate() {
-            params.push(hash);
-            hash_where_clause.push(format!(" hashes.hash <@ (${}, $1)", idx + 2));
+            for (_dist, item) in items {
+                let query = db.query("SELECT
+                    hashes.id,
+                    hashes.hash,
+                    hashes.furaffinity_id,
+                    hashes.e621_id,
+                    hashes.twitter_id,
+                    CASE
+                        WHEN furaffinity_id IS NOT NULL THEN (f.url)
+                        WHEN e621_id IS NOT NULL THEN (e.data->>'file_url')
+                        WHEN twitter_id IS NOT NULL THEN (tm.url)
+                    END url,
+                    CASE
+                        WHEN furaffinity_id IS NOT NULL THEN (f.filename)
+                        WHEN e621_id IS NOT NULL THEN ((e.data->>'md5') || '.' || (e.data->>'file_ext'))
+                        WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))
+                    END filename,
+                    CASE
+                        WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))
+                        WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'artist'))
+                        WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')
+                    END artists,
+                    CASE
+                        WHEN furaffinity_id IS NOT NULL THEN (f.file_id)
+                    END file_id,
+                    CASE
+                        WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))
+                    END sources
+                FROM
+                    hashes
+                LEFT JOIN LATERAL (
+                    SELECT *
+                    FROM submission
+                    JOIN artist ON submission.artist_id = artist.id
+                    WHERE submission.id = hashes.furaffinity_id
+                ) f ON hashes.furaffinity_id IS NOT NULL
+                LEFT JOIN LATERAL (
+                    SELECT *
+                    FROM e621
+                    WHERE e621.id = hashes.e621_id
+                ) e ON hashes.e621_id IS NOT NULL
+                LEFT JOIN LATERAL (
+                    SELECT *
+                    FROM tweet
+                    WHERE tweet.id = hashes.twitter_id
+                ) tw ON hashes.twitter_id IS NOT NULL
+                LEFT JOIN LATERAL (
+                    SELECT *
+                    FROM tweet_media
+                    WHERE
+                        tweet_media.tweet_id = hashes.twitter_id AND
+                        tweet_media.hash <@ (hashes.hash, 0)
+                    LIMIT 1
+                ) tm ON hashes.twitter_id IS NOT NULL
+                WHERE hashes.id = $1", &[&item.id]).await;
+                let rows = query.map(|rows| extract_rows(rows, hash.as_deref()).into_iter().collect());
+                tx.send(rows).await.unwrap();
+            }
         }
-        let hash_where_clause = hash_where_clause.join(" OR ");
-
-        let hash_query = format!(
-            "SELECT
-                hashes.id,
-                hashes.hash,
-                hashes.furaffinity_id,
-                hashes.e621_id,
-                hashes.twitter_id,
-                CASE
-                    WHEN furaffinity_id IS NOT NULL THEN (f.url)
-                    WHEN e621_id IS NOT NULL THEN (e.data->>'file_url')
-                    WHEN twitter_id IS NOT NULL THEN (tm.url)
-                END url,
-                CASE
-                    WHEN furaffinity_id IS NOT NULL THEN (f.filename)
-                    WHEN e621_id IS NOT NULL THEN ((e.data->>'md5') || '.' || (e.data->>'file_ext'))
-                    WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))
-                END filename,
-                CASE
-                    WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))
-                    WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'artist'))
-                    WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')
-                END artists,
-                CASE
-                    WHEN furaffinity_id IS NOT NULL THEN (f.file_id)
-                END file_id,
-                CASE
-                    WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))
-                END sources
-            FROM
-                hashes
-            LEFT JOIN LATERAL (
-                SELECT *
-                FROM submission
-                JOIN artist ON submission.artist_id = artist.id
-                WHERE submission.id = hashes.furaffinity_id
-            ) f ON hashes.furaffinity_id IS NOT NULL
-            LEFT JOIN LATERAL (
-                SELECT *
-                FROM e621
-                WHERE e621.id = hashes.e621_id
-            ) e ON hashes.e621_id IS NOT NULL
-            LEFT JOIN LATERAL (
-                SELECT *
-                FROM tweet
-                WHERE tweet.id = hashes.twitter_id
-            ) tw ON hashes.twitter_id IS NOT NULL
-            LEFT JOIN LATERAL (
-                SELECT *
-                FROM tweet_media
-                WHERE
-                    tweet_media.tweet_id = hashes.twitter_id AND
-                    tweet_media.hash <@ (hashes.hash, 0)
-                LIMIT 1
-            ) tm ON hashes.twitter_id IS NOT NULL
-            WHERE {}", hash_where_clause);
-
-        let query = db.query::<str>(&*hash_query, &params).await;
-        let rows = query.map(|rows| extract_rows(rows, hash.as_deref()).into_iter().collect());
-        tx.send(rows).await.unwrap();
     }.in_current_span());
 
     rx
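Note (reviewer addition, not part of the patch): image_query_sync now sends one message per BK-tree hit rather than a single message for one big OR-joined query, which is presumably why the channel capacity grows from 1 to 50. A hypothetical consumer, assuming the crate's Pool and Tree aliases are in scope:

    // Drain the receiver: each message is the extracted rows for a
    // single matched hash id, streamed as soon as its query finishes.
    async fn print_matches(pool: Pool, tree: Tree, hash: i64) {
        let mut results = crate::models::image_query_sync(pool, tree, vec![hash], 10, None);

        while let Some(batch) = results.recv().await {
            match batch {
                Ok(files) => println!("got {} matching files", files.len()),
                Err(err) => eprintln!("query failed: {}", err),
            }
        }
    }

One trade-off visible here: the tree's read guard is held while each per-id query awaits, so the 30-second refresh task's write lock waits for the current batch to finish.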