Initial attempt at an in-memory tree.

This commit is contained in:
Syfaro 2020-02-15 23:50:09 -06:00
parent a68a46acf4
commit 904d3290e1
7 changed files with 260 additions and 112 deletions

View File

@ -8,7 +8,7 @@ platform:
arch: amd64 arch: amd64
steps: steps:
- name: docker - name: build-latest
image: plugins/docker image: plugins/docker
settings: settings:
auto_tag: true auto_tag: true
@ -18,9 +18,23 @@ steps:
repo: registry.huefox.com/fuzzysearch repo: registry.huefox.com/fuzzysearch
username: username:
from_secret: docker_username from_secret: docker_username
when:
trigger:
branch: branch:
- master - master
- name: build-branch
image: plugins/docker
settings:
password:
from_secret: docker_password
registry: registry.huefox.com
repo: registry.huefox.com/fuzzysearch
tags: ${DRONE_BRANCH}
username:
from_secret: docker_username
when:
branch:
exclude:
- master
... ...

7
Cargo.lock generated
View File

@ -131,6 +131,12 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "bk-tree"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5488039ea2c6de8668351415e39a0218a8955bffadcff0cf01d1293a20854584"
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.7.3" version = "0.7.3"
@ -501,6 +507,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"bb8", "bb8",
"bb8-postgres", "bb8-postgres",
"bk-tree",
"bytes 0.5.4", "bytes 0.5.4",
"chrono", "chrono",
"futures", "futures",

View File

@ -32,6 +32,8 @@ img_hash = "3"
image = "0.22" image = "0.22"
hamming = "0.1" hamming = "0.1"
bk-tree = "0.3"
[profile.release] [profile.release]
lto = true lto = true
codegen-units = 1 codegen-units = 1

View File

@ -1,12 +1,15 @@
use crate::types::*; use crate::types::*;
use crate::{handlers, Pool}; use crate::{handlers, Pool, Tree};
use std::convert::Infallible; use std::convert::Infallible;
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search(
search_image(db.clone()) db: Pool,
.or(search_hashes(db.clone())) tree: Tree,
.or(stream_search_image(db.clone())) ) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_image(db.clone(), tree.clone())
.or(search_hashes(db.clone(), tree.clone()))
.or(stream_search_image(db.clone(), tree))
.or(search_file(db)) .or(search_file(db))
} }
@ -20,35 +23,45 @@ pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Reject
.and_then(handlers::search_file) .and_then(handlers::search_file)
} }
pub fn search_image(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search_image(
db: Pool,
tree: Tree,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("image") warp::path("image")
.and(with_telem()) .and(with_telem())
.and(warp::post()) .and(warp::post())
.and(warp::multipart::form().max_length(1024 * 1024 * 10)) .and(warp::multipart::form().max_length(1024 * 1024 * 10))
.and(warp::query::<ImageSearchOpts>()) .and(warp::query::<ImageSearchOpts>())
.and(with_pool(db)) .and(with_pool(db))
.and(with_tree(tree))
.and(with_api_key()) .and(with_api_key())
.and_then(handlers::search_image) .and_then(handlers::search_image)
} }
pub fn search_hashes(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { pub fn search_hashes(
db: Pool,
tree: Tree,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("hashes") warp::path("hashes")
.and(with_telem()) .and(with_telem())
.and(warp::get()) .and(warp::get())
.and(warp::query::<HashSearchOpts>()) .and(warp::query::<HashSearchOpts>())
.and(with_pool(db)) .and(with_pool(db))
.and(with_tree(tree))
.and(with_api_key()) .and(with_api_key())
.and_then(handlers::search_hashes) .and_then(handlers::search_hashes)
} }
pub fn stream_search_image( pub fn stream_search_image(
db: Pool, db: Pool,
tree: Tree,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone { ) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("stream") warp::path("stream")
.and(with_telem()) .and(with_telem())
.and(warp::post()) .and(warp::post())
.and(warp::multipart::form().max_length(1024 * 1024 * 10)) .and(warp::multipart::form().max_length(1024 * 1024 * 10))
.and(with_pool(db)) .and(with_pool(db))
.and(with_tree(tree))
.and(with_api_key()) .and(with_api_key())
.and_then(handlers::stream_image) .and_then(handlers::stream_image)
} }
@ -61,6 +74,10 @@ fn with_pool(db: Pool) -> impl Filter<Extract = (Pool,), Error = Infallible> + C
warp::any().map(move || db.clone()) warp::any().map(move || db.clone())
} }
fn with_tree(tree: Tree) -> impl Filter<Extract = (Tree,), Error = Infallible> + Clone {
warp::any().map(move || tree.clone())
}
fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Clone { fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Clone {
warp::any() warp::any()
.and(warp::header::optional("traceparent")) .and(warp::header::optional("traceparent"))
@ -75,7 +92,7 @@ fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Cl
tracing::trace!("got context from request: {:?}", context); tracing::trace!("got context from request: {:?}", context);
let span = if context.is_valid() { if context.is_valid() {
let tracer = opentelemetry::global::trace_provider().get_tracer("api"); let tracer = opentelemetry::global::trace_provider().get_tracer("api");
let span = tracer.start("context", Some(context)); let span = tracer.start("context", Some(context));
tracer.mark_span_as_active(&span); tracer.mark_span_as_active(&span);
@ -83,8 +100,6 @@ fn with_telem() -> impl Filter<Extract = (crate::Span,), Error = Rejection> + Cl
Some(span) Some(span)
} else { } else {
None None
}; }
span
}) })
} }

View File

@ -1,6 +1,6 @@
use crate::models::{image_query, image_query_sync}; use crate::models::{image_query, image_query_sync};
use crate::types::*; use crate::types::*;
use crate::{rate_limit, Pool}; use crate::{rate_limit, Pool, Tree};
use tracing::{span, warn}; use tracing::{span, warn};
use tracing_futures::Instrument; use tracing_futures::Instrument;
use warp::{reject, Rejection, Reply}; use warp::{reject, Rejection, Reply};
@ -76,12 +76,13 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
(i64::from_be_bytes(buf), hash) (i64::from_be_bytes(buf), hash)
} }
#[tracing::instrument(skip(_telem, form, pool, api_key))] #[tracing::instrument(skip(_telem, form, pool, tree, api_key))]
pub async fn search_image( pub async fn search_image(
_telem: crate::Span, _telem: crate::Span,
form: warp::multipart::FormData, form: warp::multipart::FormData,
opts: ImageSearchOpts, opts: ImageSearchOpts,
pool: Pool, pool: Pool,
tree: Tree,
api_key: String, api_key: String,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let db = pool.get().await.map_err(map_bb8_err)?; let db = pool.get().await.map_err(map_bb8_err)?;
@ -92,15 +93,33 @@ pub async fn search_image(
let mut items = { let mut items = {
if opts.search_type == Some(ImageSearchType::Force) { if opts.search_type == Some(ImageSearchType::Force) {
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec())) image_query(
pool.clone(),
tree.clone(),
vec![num],
10,
Some(hash.as_bytes().to_vec()),
)
.await .await
.unwrap() .unwrap()
} else { } else {
let results = image_query(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec())) let results = image_query(
pool.clone(),
tree.clone(),
vec![num],
0,
Some(hash.as_bytes().to_vec()),
)
.await .await
.unwrap(); .unwrap();
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) { if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
image_query(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec())) image_query(
pool.clone(),
tree.clone(),
vec![num],
10,
Some(hash.as_bytes().to_vec()),
)
.await .await
.unwrap() .unwrap()
} else { } else {
@ -124,11 +143,12 @@ pub async fn search_image(
Ok(warp::reply::json(&similarity)) Ok(warp::reply::json(&similarity))
} }
#[tracing::instrument(skip(_telem, form, pool, api_key))] #[tracing::instrument(skip(_telem, form, pool, tree, api_key))]
pub async fn stream_image( pub async fn stream_image(
_telem: crate::Span, _telem: crate::Span,
form: warp::multipart::FormData, form: warp::multipart::FormData,
pool: Pool, pool: Pool,
tree: Tree,
api_key: String, api_key: String,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
use futures_util::StreamExt; use futures_util::StreamExt;
@ -139,16 +159,15 @@ pub async fn stream_image(
let (num, hash) = hash_input(form).await; let (num, hash) = hash_input(form).await;
let exact_event_stream = let event_stream = image_query_sync(
image_query_sync(pool.clone(), vec![num], 0, Some(hash.as_bytes().to_vec())) pool.clone(),
tree,
vec![num],
10,
Some(hash.as_bytes().to_vec()),
)
.map(sse_matches); .map(sse_matches);
let close_event_stream =
image_query_sync(pool.clone(), vec![num], 10, Some(hash.as_bytes().to_vec()))
.map(sse_matches);
let event_stream = futures::stream::select(exact_event_stream, close_event_stream);
Ok(warp::sse::reply(event_stream)) Ok(warp::sse::reply(event_stream))
} }
@ -160,11 +179,12 @@ fn sse_matches(
Ok(warp::sse::json(items)) Ok(warp::sse::json(items))
} }
#[tracing::instrument(skip(_telem, form, db, api_key))] #[tracing::instrument(skip(_telem, form, db, tree, api_key))]
pub async fn search_hashes( pub async fn search_hashes(
_telem: crate::Span, _telem: crate::Span,
opts: HashSearchOpts, opts: HashSearchOpts,
db: Pool, db: Pool,
tree: Tree,
api_key: String, api_key: String,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let pool = db.clone(); let pool = db.clone();
@ -183,7 +203,7 @@ pub async fn search_hashes(
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let mut results = image_query_sync(pool, hashes.clone(), 10, None); let mut results = image_query_sync(pool, tree, hashes.clone(), 10, None);
let mut matches = Vec::new(); let mut matches = Vec::new();
while let Some(r) = results.recv().await { while let Some(r) = results.recv().await {

View File

@ -1,6 +1,8 @@
#![recursion_limit = "256"] #![recursion_limit = "256"]
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
mod filters; mod filters;
mod handlers; mod handlers;
@ -60,6 +62,28 @@ fn configure_tracing() {
.expect("Unable to set default tracing subscriber"); .expect("Unable to set default tracing subscriber");
} }
#[derive(Debug)]
pub struct Node {
id: i32,
hash: [u8; 8],
}
impl Node {
pub fn query(hash: [u8; 8]) -> Self {
Self { id: -1, hash }
}
}
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
pub struct Hamming;
impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u64 {
hamming::distance_fast(&a.hash, &b.hash).unwrap()
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
pretty_env_logger::init(); pretty_env_logger::init();
@ -78,6 +102,76 @@ async fn main() {
.await .await
.expect("Unable to build Postgres pool"); .expect("Unable to build Postgres pool");
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
let mut max_id = 0;
let conn = db_pool.get().await.unwrap();
let mut lock = tree.write().await;
conn.query("SELECT id, hash FROM hashes", &[])
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
if id > max_id {
max_id = id;
}
lock.add(Node { id, hash: bytes });
});
drop(lock);
drop(conn);
let tree_clone = tree.clone();
let pool_clone = db_pool.clone();
tokio::spawn(async move {
use futures_util::StreamExt;
let max_id = std::sync::atomic::AtomicI32::new(max_id);
let tree = tree_clone;
let pool = pool_clone;
let order = std::sync::atomic::Ordering::SeqCst;
let interval = tokio::time::interval(std::time::Duration::from_secs(30));
interval
.for_each(|_| async {
tracing::debug!("Refreshing hashes");
let conn = pool.get().await.unwrap();
let mut lock = tree.write().await;
let id = max_id.load(order);
let mut count = 0;
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
if id > max_id.load(order) {
max_id.store(id, order);
}
lock.add(Node { id, hash: bytes });
count += 1;
});
tracing::trace!("Added {} new hashes", count);
})
.await;
});
let log = warp::log("fuzzysearch"); let log = warp::log("fuzzysearch");
let cors = warp::cors() let cors = warp::cors()
.allow_any_origin() .allow_any_origin()
@ -86,7 +180,7 @@ async fn main() {
let options = warp::options().map(|| ""); let options = warp::options().map(|| "");
let api = options.or(filters::search(db_pool)); let api = options.or(filters::search(db_pool, tree));
let routes = api let routes = api
.or(warp::path::end() .or(warp::path::end()
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net")))) .map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))

View File

@ -1,6 +1,6 @@
use crate::types::*; use crate::types::*;
use crate::utils::extract_rows; use crate::utils::extract_rows;
use crate::Pool; use crate::{Pool, Tree};
use tracing_futures::Instrument; use tracing_futures::Instrument;
pub type DB<'a> = pub type DB<'a> =
@ -39,14 +39,15 @@ pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
} }
} }
#[tracing::instrument(skip(pool))] #[tracing::instrument(skip(pool, tree))]
pub async fn image_query( pub async fn image_query(
pool: Pool, pool: Pool,
tree: Tree,
hashes: Vec<i64>, hashes: Vec<i64>,
distance: i64, distance: i64,
hash: Option<Vec<u8>>, hash: Option<Vec<u8>>,
) -> Result<Vec<File>, tokio_postgres::Error> { ) -> Result<Vec<File>, tokio_postgres::Error> {
let mut results = image_query_sync(pool, hashes, distance, hash); let mut results = image_query_sync(pool, tree, hashes, distance, hash);
let mut matches = Vec::new(); let mut matches = Vec::new();
while let Some(r) = results.recv().await { while let Some(r) = results.recv().await {
@ -56,31 +57,26 @@ pub async fn image_query(
Ok(matches) Ok(matches)
} }
#[tracing::instrument(skip(pool))] #[tracing::instrument(skip(pool, tree))]
pub fn image_query_sync( pub fn image_query_sync(
pool: Pool, pool: Pool,
tree: Tree,
hashes: Vec<i64>, hashes: Vec<i64>,
distance: i64, distance: i64,
hash: Option<Vec<u8>>, hash: Option<Vec<u8>>,
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> { ) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
let (mut tx, rx) = tokio::sync::mpsc::channel(1); let (mut tx, rx) = tokio::sync::mpsc::channel(50);
tokio::spawn(async move { tokio::spawn(async move {
let db = pool.get().await.unwrap(); let db = pool.get().await.unwrap();
let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = for query_hash in hashes {
Vec::with_capacity(hashes.len() + 1); let node = crate::Node::query(query_hash.to_be_bytes());
params.insert(0, &distance); let lock = tree.read().await;
let items = lock.find(&node, distance as u64);
let mut hash_where_clause = Vec::with_capacity(hashes.len()); for (_dist, item) in items {
for (idx, hash) in hashes.iter().enumerate() { let query = db.query("SELECT
params.push(hash);
hash_where_clause.push(format!(" hashes.hash <@ (${}, $1)", idx + 2));
}
let hash_where_clause = hash_where_clause.join(" OR ");
let hash_query = format!(
"SELECT
hashes.id, hashes.id,
hashes.hash, hashes.hash,
hashes.furaffinity_id, hashes.furaffinity_id,
@ -133,11 +129,11 @@ pub fn image_query_sync(
tweet_media.hash <@ (hashes.hash, 0) tweet_media.hash <@ (hashes.hash, 0)
LIMIT 1 LIMIT 1
) tm ON hashes.twitter_id IS NOT NULL ) tm ON hashes.twitter_id IS NOT NULL
WHERE {}", hash_where_clause); WHERE hashes.id = $1", &[&item.id]).await;
let query = db.query::<str>(&*hash_query, &params).await;
let rows = query.map(|rows| extract_rows(rows, hash.as_deref()).into_iter().collect()); let rows = query.map(|rows| extract_rows(rows, hash.as_deref()).into_iter().collect());
tx.send(rows).await.unwrap(); tx.send(rows).await.unwrap();
}
}
}.in_current_span()); }.in_current_span());
rx rx