mirror of
https://github.com/Syfaro/fuzzysearch.git
synced 2024-11-05 14:32:56 +00:00
Use NOTIFY/LISTEN instead of polling for updates (#3)
* Use NOTIFY/LISTEN instead of polling for updates. * Allow different distances for multiple hashes.
This commit is contained in:
parent
06a1c7b466
commit
908cda8ce9
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
/target
|
||||
.env
|
||||
|
678
Cargo.lock
generated
678
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -28,9 +28,7 @@ warp = "0.3"
|
||||
reqwest = "0.11"
|
||||
hyper = "0.14"
|
||||
|
||||
tokio-postgres = "0.7"
|
||||
bb8 = "0.7"
|
||||
bb8-postgres = "0.7"
|
||||
sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "postgres", "macros", "json", "offline"] }
|
||||
|
||||
img_hash = "3"
|
||||
image = "0.23"
|
||||
|
@ -1,5 +1,6 @@
|
||||
FROM rust:1-slim AS builder
|
||||
WORKDIR /src
|
||||
ENV SQLX_OFFLINE=true
|
||||
RUN apt-get update -y && apt-get install -y libssl-dev pkg-config
|
||||
COPY . .
|
||||
RUN cargo install --root / --path .
|
||||
|
194
sqlx-data.json
Normal file
194
sqlx-data.json
Normal file
@ -0,0 +1,194 @@
|
||||
{
|
||||
"db": "PostgreSQL",
|
||||
"1984ce60f052d6a29638f8e05b35671b8edfbf273783d4b843ebd35cbb8a391f": {
|
||||
"query": "INSERT INTO\n rate_limit (api_key_id, time_window, group_name, count)\n VALUES\n ($1, $2, $3, $4)\n ON CONFLICT ON CONSTRAINT unique_window\n DO UPDATE set count = rate_limit.count + $4\n RETURNING rate_limit.count",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "count",
|
||||
"type_info": "Int2"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int4",
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int2"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"659ee9ddc1c5ccd42ba9dc1617440544c30ece449ba3ba7f9d39f447b8af3cfe": {
|
||||
"query": "SELECT\n api_key.id,\n api_key.name_limit,\n api_key.image_limit,\n api_key.hash_limit,\n api_key.name,\n account.email owner_email\n FROM\n api_key\n JOIN account\n ON account.id = api_key.user_id\n WHERE\n api_key.key = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "name_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "image_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "hash_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "name",
|
||||
"type_info": "Varchar"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "owner_email",
|
||||
"type_info": "Varchar"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"6b8d304fc40fa539ae671e6e24e7978ad271cb7a1cafb20fc4b4096a958d790f": {
|
||||
"query": "SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
}
|
||||
},
|
||||
"f4608ccaf739d36649cdbc5297177a989cc7763006d28c97e219bb708930972a": {
|
||||
"query": "SELECT\n hashes.id,\n hashes.hash,\n hashes.furaffinity_id,\n hashes.e621_id,\n hashes.twitter_id,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.url)\n WHEN e621_id IS NOT NULL THEN (e.data->'file'->>'url')\n WHEN twitter_id IS NOT NULL THEN (tm.url)\n END url,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.filename)\n WHEN e621_id IS NOT NULL THEN ((e.data->'file'->>'md5') || '.' || (e.data->'file'->>'ext'))\n WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))\n END filename,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'tags'->'artist'))\n WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')\n END artists,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.file_id)\n END file_id,\n CASE\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))\n END sources\n FROM\n hashes\n LEFT JOIN LATERAL (\n SELECT *\n FROM submission\n JOIN artist ON submission.artist_id = artist.id\n WHERE submission.id = hashes.furaffinity_id\n ) f ON hashes.furaffinity_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM e621\n WHERE e621.id = hashes.e621_id\n ) e ON hashes.e621_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet\n WHERE tweet.id = hashes.twitter_id\n ) tw ON hashes.twitter_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet_media\n WHERE\n tweet_media.tweet_id = hashes.twitter_id AND\n tweet_media.hash <@ (hashes.hash, 0)\n LIMIT 1\n ) tm ON hashes.twitter_id IS NOT NULL\n WHERE hashes.id = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "hash",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "furaffinity_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "e621_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "twitter_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "url",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "filename",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "artists",
|
||||
"type_info": "TextArray"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "file_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "sources",
|
||||
"type_info": "TextArray"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int4"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
}
|
||||
},
|
||||
"fe60be66b2d8a8f02b3bfe06d1f0e57e4bb07e80cba1b379a5f17f6cbd8b075c": {
|
||||
"query": "SELECT id, hash FROM hashes",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "hash",
|
||||
"type_info": "Int8"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
201
src/handlers.rs
201
src/handlers.rs
@ -8,8 +8,7 @@ use warp::{Rejection, Reply};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
Bb8(bb8::RunError<tokio_postgres::Error>),
|
||||
Postgres(tokio_postgres::Error),
|
||||
Postgres(sqlx::Error),
|
||||
Reqwest(reqwest::Error),
|
||||
InvalidData,
|
||||
InvalidImage,
|
||||
@ -20,7 +19,7 @@ enum Error {
|
||||
impl warp::Reply for Error {
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
let msg = match self {
|
||||
Error::Bb8(_) | Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
|
||||
Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
|
||||
code: 500,
|
||||
message: "Internal server error".to_string(),
|
||||
},
|
||||
@ -51,14 +50,8 @@ impl warp::Reply for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bb8::RunError<tokio_postgres::Error>> for Error {
|
||||
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self {
|
||||
Error::Bb8(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for Error {
|
||||
fn from(err: tokio_postgres::Error) -> Self {
|
||||
impl From<sqlx::Error> for Error {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
Error::Postgres(err)
|
||||
}
|
||||
}
|
||||
@ -112,12 +105,10 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
|
||||
pub async fn search_image(
|
||||
form: warp::multipart::FormData,
|
||||
opts: ImageSearchOpts,
|
||||
pool: Pool,
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let db = early_return!(pool.get().await);
|
||||
|
||||
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
|
||||
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
|
||||
@ -126,7 +117,7 @@ pub async fn search_image(
|
||||
let mut items = {
|
||||
if opts.search_type == Some(ImageSearchType::Force) {
|
||||
image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
10,
|
||||
@ -136,7 +127,7 @@ pub async fn search_image(
|
||||
.unwrap()
|
||||
} else {
|
||||
let results = image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
0,
|
||||
@ -146,7 +137,7 @@ pub async fn search_image(
|
||||
.unwrap();
|
||||
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
|
||||
image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
10,
|
||||
@ -194,10 +185,8 @@ pub async fn stream_image(
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let db = early_return!(pool.get().await);
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image", 2);
|
||||
rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
rate_limit!(&api_key, &pool, image_limit, "image", 2);
|
||||
rate_limit!(&api_key, &pool, hash_limit, "hash");
|
||||
|
||||
let (num, hash) = hash_input(form).await;
|
||||
|
||||
@ -220,7 +209,7 @@ pub async fn stream_image(
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn sse_matches(
|
||||
matches: Result<Vec<File>, tokio_postgres::Error>,
|
||||
matches: Result<Vec<File>, sqlx::Error>,
|
||||
) -> Result<warp::sse::Event, core::convert::Infallible> {
|
||||
let items = matches.unwrap();
|
||||
|
||||
@ -234,7 +223,6 @@ pub async fn search_hashes(
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let pool = db.clone();
|
||||
let db = early_return!(db.get().await);
|
||||
|
||||
let hashes: Vec<i64> = opts
|
||||
.hashes
|
||||
@ -280,64 +268,95 @@ pub async fn search_file(
|
||||
db: Pool,
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let db = early_return!(db.get().await);
|
||||
use sqlx::Row;
|
||||
|
||||
let file_remaining = rate_limit!(&api_key, &db, name_limit, "file");
|
||||
|
||||
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) =
|
||||
if let Some(ref id) = opts.id {
|
||||
("file_id = $1", id)
|
||||
} else if let Some(ref name) = opts.name {
|
||||
("lower(filename) = lower($1)", name)
|
||||
} else if let Some(ref url) = opts.url {
|
||||
("lower(url) = lower($1)", url)
|
||||
} else {
|
||||
return Ok(Box::new(Error::InvalidData));
|
||||
};
|
||||
let query = if let Some(ref id) = opts.id {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
file_id = $1
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(id)
|
||||
} else if let Some(ref name) = opts.name {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
lower(filename) = lower($1)
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(name)
|
||||
} else if let Some(ref url) = opts.url {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
lower(url) = lower($1)
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(url)
|
||||
} else {
|
||||
return Ok(Box::new(Error::InvalidData));
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
{}
|
||||
LIMIT 10",
|
||||
filter
|
||||
);
|
||||
let matches: Result<Vec<File>, _> = query
|
||||
.map(|row| File {
|
||||
id: row.get("hash_id"),
|
||||
site_id: row.get::<i32, _>("id") as i64,
|
||||
site_id_str: row.get::<i32, _>("id").to_string(),
|
||||
url: row.get("url"),
|
||||
filename: row.get("filename"),
|
||||
artists: row
|
||||
.get::<Option<String>, _>("name")
|
||||
.map(|artist| vec![artist]),
|
||||
distance: None,
|
||||
hash: None,
|
||||
site_info: Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||
file_id: row.get("file_id"),
|
||||
})),
|
||||
searched_hash: None,
|
||||
})
|
||||
.fetch_all(&db)
|
||||
.await;
|
||||
|
||||
let matches: Vec<_> = early_return!(
|
||||
db.query::<str>(&*query, &[val])
|
||||
.instrument(span!(tracing::Level::TRACE, "waiting for db"))
|
||||
.await
|
||||
)
|
||||
.into_iter()
|
||||
.map(|row| File {
|
||||
id: row.get("hash_id"),
|
||||
site_id: row.get::<&str, i32>("id") as i64,
|
||||
site_id_str: row.get::<&str, i32>("id").to_string(),
|
||||
url: row.get("url"),
|
||||
filename: row.get("filename"),
|
||||
artists: row
|
||||
.get::<&str, Option<String>>("name")
|
||||
.map(|artist| vec![artist]),
|
||||
distance: None,
|
||||
hash: None,
|
||||
site_info: Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||
file_id: row.get("file_id"),
|
||||
})),
|
||||
searched_hash: None,
|
||||
})
|
||||
.collect();
|
||||
let matches = early_return!(matches);
|
||||
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-rate-limit-total-file", file_remaining.1.to_string())
|
||||
@ -350,17 +369,13 @@ pub async fn search_file(
|
||||
}
|
||||
|
||||
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let db = early_return!(db.get().await);
|
||||
|
||||
let exists = if let Some(handle) = opts.twitter {
|
||||
!early_return!(
|
||||
db.query(
|
||||
"SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)",
|
||||
&[&handle],
|
||||
)
|
||||
let result = sqlx::query_scalar!("SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", handle)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
)
|
||||
.is_empty()
|
||||
.map(|row| row.flatten().unwrap_or(false));
|
||||
|
||||
early_return!(result)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
@ -370,7 +385,7 @@ pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>,
|
||||
|
||||
pub async fn search_image_by_url(
|
||||
opts: UrlSearchOpts,
|
||||
pool: Pool,
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
@ -378,8 +393,6 @@ pub async fn search_image_by_url(
|
||||
|
||||
let url = opts.url;
|
||||
|
||||
let db = early_return!(pool.get().await);
|
||||
|
||||
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
|
||||
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
|
||||
@ -424,15 +437,9 @@ pub async fn search_image_by_url(
|
||||
let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
|
||||
let num = i64::from_be_bytes(hash);
|
||||
|
||||
let results = image_query(
|
||||
pool.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
3,
|
||||
Some(hash.to_vec()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-image-hash", num.to_string())
|
||||
|
152
src/main.rs
152
src/main.rs
@ -1,6 +1,5 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use warp::Filter;
|
||||
@ -12,7 +11,7 @@ mod types;
|
||||
mod utils;
|
||||
|
||||
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
||||
type Pool = bb8::Pool<bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||
type Pool = sqlx::PgPool;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
@ -38,93 +37,15 @@ impl bk_tree::Metric<Node> for Hamming {
|
||||
async fn main() {
|
||||
configure_tracing();
|
||||
|
||||
let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN");
|
||||
let s = std::env::var("DATABASE_URL").expect("Missing DATABASE_URL");
|
||||
|
||||
let manager = bb8_postgres::PostgresConnectionManager::new(
|
||||
tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"),
|
||||
tokio_postgres::NoTls,
|
||||
);
|
||||
|
||||
let db_pool = bb8::Pool::builder()
|
||||
.build(manager)
|
||||
let db_pool = sqlx::PgPool::connect(&s)
|
||||
.await
|
||||
.expect("Unable to build Postgres pool");
|
||||
.expect("Unable to create Postgres pool");
|
||||
|
||||
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
|
||||
|
||||
let mut max_id = 0;
|
||||
|
||||
let conn = db_pool.get().await.unwrap();
|
||||
let mut lock = tree.write().await;
|
||||
conn.query("SELECT id, hash FROM hashes", &[])
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.for_each(|row| {
|
||||
let id: i32 = row.get(0);
|
||||
let hash: i64 = row.get(1);
|
||||
let bytes = hash.to_be_bytes();
|
||||
|
||||
if id > max_id {
|
||||
max_id = id;
|
||||
}
|
||||
|
||||
lock.add(Node { id, hash: bytes });
|
||||
});
|
||||
drop(lock);
|
||||
drop(conn);
|
||||
|
||||
let tree_clone = tree.clone();
|
||||
let pool_clone = db_pool.clone();
|
||||
tokio::spawn(async move {
|
||||
use futures::StreamExt;
|
||||
|
||||
let max_id = std::sync::atomic::AtomicI32::new(max_id);
|
||||
let tree = tree_clone;
|
||||
let pool = pool_clone;
|
||||
|
||||
let order = std::sync::atomic::Ordering::SeqCst;
|
||||
|
||||
let interval = async_stream::stream! {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
|
||||
while let item = interval.tick().await {
|
||||
yield item;
|
||||
}
|
||||
};
|
||||
|
||||
interval
|
||||
.for_each(|_| async {
|
||||
tracing::debug!("Refreshing hashes");
|
||||
|
||||
let conn = pool.get().await.unwrap();
|
||||
let mut lock = tree.write().await;
|
||||
let id = max_id.load(order);
|
||||
|
||||
let mut count: usize = 0;
|
||||
|
||||
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.for_each(|row| {
|
||||
let id: i32 = row.get(0);
|
||||
let hash: i64 = row.get(1);
|
||||
let bytes = hash.to_be_bytes();
|
||||
|
||||
if id > max_id.load(order) {
|
||||
max_id.store(id, order);
|
||||
}
|
||||
|
||||
lock.add(Node { id, hash: bytes });
|
||||
|
||||
count += 1;
|
||||
});
|
||||
|
||||
tracing::trace!("Added {} new hashes", count);
|
||||
})
|
||||
.await;
|
||||
});
|
||||
load_updates(db_pool.clone(), tree.clone()).await;
|
||||
|
||||
let log = warp::log("fuzzysearch");
|
||||
let cors = warp::cors()
|
||||
@ -198,6 +119,69 @@ fn configure_tracing() {
|
||||
registry.init();
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct HashRow {
|
||||
id: i32,
|
||||
hash: i64,
|
||||
}
|
||||
|
||||
async fn create_tree(conn: &Pool) -> bk_tree::BKTree<Node, Hamming> {
|
||||
use futures::TryStreamExt;
|
||||
|
||||
let mut tree = bk_tree::BKTree::new(Hamming);
|
||||
|
||||
let mut rows = sqlx::query_as!(HashRow, "SELECT id, hash FROM hashes").fetch(conn);
|
||||
|
||||
while let Some(row) = rows.try_next().await.expect("Unable to get row") {
|
||||
tree.add(Node {
|
||||
id: row.id,
|
||||
hash: row.hash.to_be_bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
tree
|
||||
}
|
||||
|
||||
async fn load_updates(conn: Pool, tree: Tree) {
|
||||
let mut listener = sqlx::postgres::PgListener::connect_with(&conn)
|
||||
.await
|
||||
.unwrap();
|
||||
listener.listen("fuzzysearch_hash_added").await.unwrap();
|
||||
|
||||
let new_tree = create_tree(&conn).await;
|
||||
let mut lock = tree.write().await;
|
||||
*lock = new_tree;
|
||||
drop(lock);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
while let Some(notification) = listener
|
||||
.try_recv()
|
||||
.await
|
||||
.expect("Unable to recv notification")
|
||||
{
|
||||
let payload: HashRow = serde_json::from_str(notification.payload()).unwrap();
|
||||
tracing::debug!(id = payload.id, "Adding new hash to tree");
|
||||
|
||||
let mut lock = tree.write().await;
|
||||
lock.add(Node {
|
||||
id: payload.id,
|
||||
hash: payload.hash.to_be_bytes(),
|
||||
});
|
||||
drop(lock);
|
||||
}
|
||||
|
||||
tracing::error!("Lost connection to Postgres, recreating tree");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
let new_tree = create_tree(&conn).await;
|
||||
let mut lock = tree.write().await;
|
||||
*lock = new_tree;
|
||||
drop(lock);
|
||||
tracing::info!("Replaced tree");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn get_hasher() -> img_hash::Hasher<[u8; 8]> {
|
||||
use img_hash::{HashAlg::Gradient, HasherConfig};
|
||||
|
||||
|
102
src/models.rs
102
src/models.rs
@ -1,44 +1,31 @@
|
||||
use crate::types::*;
|
||||
use crate::utils::extract_rows;
|
||||
use crate::{Pool, Tree};
|
||||
use tracing_futures::Instrument;
|
||||
|
||||
pub type Db<'a> =
|
||||
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||
|
||||
#[tracing::instrument(skip(db))]
|
||||
pub async fn lookup_api_key(key: &str, db: Db<'_>) -> Option<ApiKey> {
|
||||
let rows = db
|
||||
.query(
|
||||
"SELECT
|
||||
pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option<ApiKey> {
|
||||
sqlx::query_as!(
|
||||
ApiKey,
|
||||
"SELECT
|
||||
api_key.id,
|
||||
api_key.name_limit,
|
||||
api_key.image_limit,
|
||||
api_key.hash_limit,
|
||||
api_key.name,
|
||||
account.email
|
||||
account.email owner_email
|
||||
FROM
|
||||
api_key
|
||||
JOIN account
|
||||
ON account.id = api_key.user_id
|
||||
WHERE
|
||||
api_key.key = $1",
|
||||
&[&key],
|
||||
)
|
||||
.await
|
||||
.expect("Unable to query API keys");
|
||||
|
||||
match rows.into_iter().next() {
|
||||
Some(row) => Some(ApiKey {
|
||||
id: row.get(0),
|
||||
name_limit: row.get(1),
|
||||
image_limit: row.get(2),
|
||||
hash_limit: row.get(3),
|
||||
name: row.get(4),
|
||||
owner_email: row.get(5),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
api_key.key = $1
|
||||
",
|
||||
key
|
||||
)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(pool, tree))]
|
||||
@ -48,7 +35,7 @@ pub async fn image_query(
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> Result<Vec<File>, tokio_postgres::Error> {
|
||||
) -> Result<Vec<File>, sqlx::Error> {
|
||||
let mut results = image_query_sync(pool, tree, hashes, distance, hash);
|
||||
let mut matches = Vec::new();
|
||||
|
||||
@ -66,19 +53,26 @@ pub fn image_query_sync(
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> {
|
||||
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, sqlx::Error>> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(50);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let db = pool.get().await.unwrap();
|
||||
let db = pool;
|
||||
|
||||
for query_hash in hashes {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
|
||||
let node = crate::Node::query(query_hash.to_be_bytes());
|
||||
let lock = tree.read().await;
|
||||
let items = lock.find(&node, distance as u64);
|
||||
|
||||
for (_dist, item) in items {
|
||||
let query = db.query("SELECT
|
||||
for (dist, item) in items {
|
||||
if seen.contains(&item.id) {
|
||||
continue;
|
||||
}
|
||||
seen.insert(item.id);
|
||||
|
||||
let row = sqlx::query!("SELECT
|
||||
hashes.id,
|
||||
hashes.hash,
|
||||
hashes.furaffinity_id,
|
||||
@ -131,14 +125,44 @@ pub fn image_query_sync(
|
||||
tweet_media.hash <@ (hashes.hash, 0)
|
||||
LIMIT 1
|
||||
) tm ON hashes.twitter_id IS NOT NULL
|
||||
WHERE hashes.id = $1", &[&item.id]).await;
|
||||
let rows = query.map(|rows| {
|
||||
extract_rows(rows, hash.as_deref()).into_iter().map(|mut file| {
|
||||
file.searched_hash = Some(query_hash);
|
||||
file
|
||||
}).collect()
|
||||
});
|
||||
tx.send(rows).await.unwrap();
|
||||
WHERE hashes.id = $1", item.id).map(|row| {
|
||||
let (site_id, site_info) = if let Some(fa_id) = row.furaffinity_id {
|
||||
(
|
||||
fa_id as i64,
|
||||
Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||
file_id: row.file_id.unwrap(),
|
||||
}))
|
||||
)
|
||||
} else if let Some(e621_id) = row.e621_id {
|
||||
(
|
||||
e621_id as i64,
|
||||
Some(SiteInfo::E621(E621File {
|
||||
sources: row.sources,
|
||||
}))
|
||||
)
|
||||
} else if let Some(twitter_id) = row.twitter_id {
|
||||
(twitter_id, Some(SiteInfo::Twitter))
|
||||
} else {
|
||||
(-1, None)
|
||||
};
|
||||
|
||||
let file = File {
|
||||
id: row.id,
|
||||
site_id,
|
||||
site_info,
|
||||
site_id_str: site_id.to_string(),
|
||||
url: row.url.unwrap_or_default(),
|
||||
hash: Some(row.hash),
|
||||
distance: Some(dist),
|
||||
artists: row.artists,
|
||||
filename: row.filename.unwrap_or_default(),
|
||||
searched_hash: Some(query_hash),
|
||||
};
|
||||
|
||||
vec![file]
|
||||
}).fetch_one(&db).await;
|
||||
|
||||
tx.send(row).await.unwrap();
|
||||
}
|
||||
}
|
||||
}.in_current_span());
|
||||
|
@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
|
||||
pub struct ApiKey {
|
||||
pub id: i32,
|
||||
pub name: Option<String>,
|
||||
pub owner_email: Option<String>,
|
||||
pub owner_email: String,
|
||||
pub name_limit: i16,
|
||||
pub image_limit: i16,
|
||||
pub hash_limit: i16,
|
||||
|
85
src/utils.rs
85
src/utils.rs
@ -1,4 +1,3 @@
|
||||
use crate::models::Db;
|
||||
use crate::types::*;
|
||||
|
||||
#[macro_export]
|
||||
@ -51,30 +50,31 @@ macro_rules! early_return {
|
||||
/// joined requests.
|
||||
#[tracing::instrument(skip(db))]
|
||||
pub async fn update_rate_limit(
|
||||
db: Db<'_>,
|
||||
db: &sqlx::PgPool,
|
||||
key_id: i32,
|
||||
key_group_limit: i16,
|
||||
group_name: &'static str,
|
||||
incr_by: i16,
|
||||
) -> Result<RateLimit, tokio_postgres::Error> {
|
||||
) -> Result<RateLimit, sqlx::Error> {
|
||||
let now = chrono::Utc::now();
|
||||
let timestamp = now.timestamp();
|
||||
let time_window = timestamp - (timestamp % 60);
|
||||
|
||||
let rows = db
|
||||
.query(
|
||||
"INSERT INTO
|
||||
rate_limit (api_key_id, time_window, group_name, count)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
ON CONFLICT ON CONSTRAINT unique_window
|
||||
DO UPDATE set count = rate_limit.count + $4
|
||||
RETURNING rate_limit.count",
|
||||
&[&key_id, &time_window, &group_name, &incr_by],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let count: i16 = rows[0].get(0);
|
||||
let count: i16 = sqlx::query_scalar!(
|
||||
"INSERT INTO
|
||||
rate_limit (api_key_id, time_window, group_name, count)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
ON CONFLICT ON CONSTRAINT unique_window
|
||||
DO UPDATE set count = rate_limit.count + $4
|
||||
RETURNING rate_limit.count",
|
||||
key_id,
|
||||
time_window,
|
||||
group_name,
|
||||
incr_by
|
||||
)
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
if count > key_group_limit {
|
||||
Ok(RateLimit::Limited)
|
||||
@ -85,54 +85,3 @@ pub async fn update_rate_limit(
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_rows(
|
||||
rows: Vec<tokio_postgres::Row>,
|
||||
hash: Option<&[u8]>,
|
||||
) -> impl IntoIterator<Item = File> + '_ {
|
||||
rows.into_iter().map(move |row| {
|
||||
let dbhash: i64 = row.get("hash");
|
||||
let dbbytes = dbhash.to_be_bytes();
|
||||
|
||||
let (furaffinity_id, e621_id, twitter_id): (Option<i32>, Option<i32>, Option<i64>) = (
|
||||
row.get("furaffinity_id"),
|
||||
row.get("e621_id"),
|
||||
row.get("twitter_id"),
|
||||
);
|
||||
|
||||
let (site_id, site_info) = if let Some(fa_id) = furaffinity_id {
|
||||
(
|
||||
fa_id as i64,
|
||||
Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||
file_id: row.get("file_id"),
|
||||
})),
|
||||
)
|
||||
} else if let Some(e6_id) = e621_id {
|
||||
(
|
||||
e6_id as i64,
|
||||
Some(SiteInfo::E621(E621File {
|
||||
sources: row.get("sources"),
|
||||
})),
|
||||
)
|
||||
} else if let Some(t_id) = twitter_id {
|
||||
(t_id, Some(SiteInfo::Twitter))
|
||||
} else {
|
||||
(-1, None)
|
||||
};
|
||||
|
||||
File {
|
||||
id: row.get("id"),
|
||||
site_id,
|
||||
site_info,
|
||||
site_id_str: site_id.to_string(),
|
||||
url: row.get("url"),
|
||||
hash: Some(dbhash),
|
||||
distance: hash
|
||||
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
|
||||
.flatten(),
|
||||
artists: row.get("artists"),
|
||||
filename: row.get("filename"),
|
||||
searched_hash: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user