Use NOTIFY/LISTEN instead of polling for updates (#3)

* Use NOTIFY/LISTEN instead of polling for updates.

* Allow different distances for multiple hashes.
This commit is contained in:
Syfaro 2021-02-17 16:30:05 -05:00 committed by GitHub
parent 06a1c7b466
commit 908cda8ce9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 891 additions and 529 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target /target
.env

678
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -28,9 +28,7 @@ warp = "0.3"
reqwest = "0.11" reqwest = "0.11"
hyper = "0.14" hyper = "0.14"
tokio-postgres = "0.7" sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "postgres", "macros", "json", "offline"] }
bb8 = "0.7"
bb8-postgres = "0.7"
img_hash = "3" img_hash = "3"
image = "0.23" image = "0.23"

View File

@ -1,5 +1,6 @@
FROM rust:1-slim AS builder FROM rust:1-slim AS builder
WORKDIR /src WORKDIR /src
ENV SQLX_OFFLINE=true
RUN apt-get update -y && apt-get install -y libssl-dev pkg-config RUN apt-get update -y && apt-get install -y libssl-dev pkg-config
COPY . . COPY . .
RUN cargo install --root / --path . RUN cargo install --root / --path .

194
sqlx-data.json Normal file
View File

@ -0,0 +1,194 @@
{
"db": "PostgreSQL",
"1984ce60f052d6a29638f8e05b35671b8edfbf273783d4b843ebd35cbb8a391f": {
"query": "INSERT INTO\n rate_limit (api_key_id, time_window, group_name, count)\n VALUES\n ($1, $2, $3, $4)\n ON CONFLICT ON CONSTRAINT unique_window\n DO UPDATE set count = rate_limit.count + $4\n RETURNING rate_limit.count",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "count",
"type_info": "Int2"
}
],
"parameters": {
"Left": [
"Int4",
"Int8",
"Text",
"Int2"
]
},
"nullable": [
false
]
}
},
"659ee9ddc1c5ccd42ba9dc1617440544c30ece449ba3ba7f9d39f447b8af3cfe": {
"query": "SELECT\n api_key.id,\n api_key.name_limit,\n api_key.image_limit,\n api_key.hash_limit,\n api_key.name,\n account.email owner_email\n FROM\n api_key\n JOIN account\n ON account.id = api_key.user_id\n WHERE\n api_key.key = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "name_limit",
"type_info": "Int2"
},
{
"ordinal": 2,
"name": "image_limit",
"type_info": "Int2"
},
{
"ordinal": 3,
"name": "hash_limit",
"type_info": "Int2"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "owner_email",
"type_info": "Varchar"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
true,
false
]
}
},
"6b8d304fc40fa539ae671e6e24e7978ad271cb7a1cafb20fc4b4096a958d790f": {
"query": "SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "exists",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
null
]
}
},
"f4608ccaf739d36649cdbc5297177a989cc7763006d28c97e219bb708930972a": {
"query": "SELECT\n hashes.id,\n hashes.hash,\n hashes.furaffinity_id,\n hashes.e621_id,\n hashes.twitter_id,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.url)\n WHEN e621_id IS NOT NULL THEN (e.data->'file'->>'url')\n WHEN twitter_id IS NOT NULL THEN (tm.url)\n END url,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.filename)\n WHEN e621_id IS NOT NULL THEN ((e.data->'file'->>'md5') || '.' || (e.data->'file'->>'ext'))\n WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))\n END filename,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'tags'->'artist'))\n WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')\n END artists,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.file_id)\n END file_id,\n CASE\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))\n END sources\n FROM\n hashes\n LEFT JOIN LATERAL (\n SELECT *\n FROM submission\n JOIN artist ON submission.artist_id = artist.id\n WHERE submission.id = hashes.furaffinity_id\n ) f ON hashes.furaffinity_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM e621\n WHERE e621.id = hashes.e621_id\n ) e ON hashes.e621_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet\n WHERE tweet.id = hashes.twitter_id\n ) tw ON hashes.twitter_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet_media\n WHERE\n tweet_media.tweet_id = hashes.twitter_id AND\n tweet_media.hash <@ (hashes.hash, 0)\n LIMIT 1\n ) tm ON hashes.twitter_id IS NOT NULL\n WHERE hashes.id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "hash",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "furaffinity_id",
"type_info": "Int4"
},
{
"ordinal": 3,
"name": "e621_id",
"type_info": "Int4"
},
{
"ordinal": 4,
"name": "twitter_id",
"type_info": "Int8"
},
{
"ordinal": 5,
"name": "url",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "filename",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "artists",
"type_info": "TextArray"
},
{
"ordinal": 8,
"name": "file_id",
"type_info": "Int4"
},
{
"ordinal": 9,
"name": "sources",
"type_info": "TextArray"
}
],
"parameters": {
"Left": [
"Int4"
]
},
"nullable": [
false,
false,
true,
true,
true,
null,
null,
null,
null,
null
]
}
},
"fe60be66b2d8a8f02b3bfe06d1f0e57e4bb07e80cba1b379a5f17f6cbd8b075c": {
"query": "SELECT id, hash FROM hashes",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "hash",
"type_info": "Int8"
}
],
"parameters": {
"Left": []
},
"nullable": [
false,
false
]
}
}
}

View File

@ -8,8 +8,7 @@ use warp::{Rejection, Reply};
#[derive(Debug)] #[derive(Debug)]
enum Error { enum Error {
Bb8(bb8::RunError<tokio_postgres::Error>), Postgres(sqlx::Error),
Postgres(tokio_postgres::Error),
Reqwest(reqwest::Error), Reqwest(reqwest::Error),
InvalidData, InvalidData,
InvalidImage, InvalidImage,
@ -20,7 +19,7 @@ enum Error {
impl warp::Reply for Error { impl warp::Reply for Error {
fn into_response(self) -> warp::reply::Response { fn into_response(self) -> warp::reply::Response {
let msg = match self { let msg = match self {
Error::Bb8(_) | Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage { Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
code: 500, code: 500,
message: "Internal server error".to_string(), message: "Internal server error".to_string(),
}, },
@ -51,14 +50,8 @@ impl warp::Reply for Error {
} }
} }
impl From<bb8::RunError<tokio_postgres::Error>> for Error { impl From<sqlx::Error> for Error {
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self { fn from(err: sqlx::Error) -> Self {
Error::Bb8(err)
}
}
impl From<tokio_postgres::Error> for Error {
fn from(err: tokio_postgres::Error) -> Self {
Error::Postgres(err) Error::Postgres(err)
} }
} }
@ -112,12 +105,10 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
pub async fn search_image( pub async fn search_image(
form: warp::multipart::FormData, form: warp::multipart::FormData,
opts: ImageSearchOpts, opts: ImageSearchOpts,
pool: Pool, db: Pool,
tree: Tree, tree: Tree,
api_key: String, api_key: String,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(pool.get().await);
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image"); let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash"); let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
@ -126,7 +117,7 @@ pub async fn search_image(
let mut items = { let mut items = {
if opts.search_type == Some(ImageSearchType::Force) { if opts.search_type == Some(ImageSearchType::Force) {
image_query( image_query(
pool.clone(), db.clone(),
tree.clone(), tree.clone(),
vec![num], vec![num],
10, 10,
@ -136,7 +127,7 @@ pub async fn search_image(
.unwrap() .unwrap()
} else { } else {
let results = image_query( let results = image_query(
pool.clone(), db.clone(),
tree.clone(), tree.clone(),
vec![num], vec![num],
0, 0,
@ -146,7 +137,7 @@ pub async fn search_image(
.unwrap(); .unwrap();
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) { if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
image_query( image_query(
pool.clone(), db.clone(),
tree.clone(), tree.clone(),
vec![num], vec![num],
10, 10,
@ -194,10 +185,8 @@ pub async fn stream_image(
tree: Tree, tree: Tree,
api_key: String, api_key: String,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(pool.get().await); rate_limit!(&api_key, &pool, image_limit, "image", 2);
rate_limit!(&api_key, &pool, hash_limit, "hash");
rate_limit!(&api_key, &db, image_limit, "image", 2);
rate_limit!(&api_key, &db, hash_limit, "hash");
let (num, hash) = hash_input(form).await; let (num, hash) = hash_input(form).await;
@ -220,7 +209,7 @@ pub async fn stream_image(
#[allow(clippy::unnecessary_wraps)] #[allow(clippy::unnecessary_wraps)]
fn sse_matches( fn sse_matches(
matches: Result<Vec<File>, tokio_postgres::Error>, matches: Result<Vec<File>, sqlx::Error>,
) -> Result<warp::sse::Event, core::convert::Infallible> { ) -> Result<warp::sse::Event, core::convert::Infallible> {
let items = matches.unwrap(); let items = matches.unwrap();
@ -234,7 +223,6 @@ pub async fn search_hashes(
api_key: String, api_key: String,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
let pool = db.clone(); let pool = db.clone();
let db = early_return!(db.get().await);
let hashes: Vec<i64> = opts let hashes: Vec<i64> = opts
.hashes .hashes
@ -280,22 +268,12 @@ pub async fn search_file(
db: Pool, db: Pool,
api_key: String, api_key: String,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(db.get().await); use sqlx::Row;
let file_remaining = rate_limit!(&api_key, &db, name_limit, "file"); let file_remaining = rate_limit!(&api_key, &db, name_limit, "file");
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) = let query = if let Some(ref id) = opts.id {
if let Some(ref id) = opts.id { sqlx::query(
("file_id = $1", id)
} else if let Some(ref name) = opts.name {
("lower(filename) = lower($1)", name)
} else if let Some(ref url) = opts.url {
("lower(url) = lower($1)", url)
} else {
return Ok(Box::new(Error::InvalidData));
};
let query = format!(
"SELECT "SELECT
submission.id, submission.id,
submission.url, submission.url,
@ -310,25 +288,63 @@ pub async fn search_file(
JOIN hashes JOIN hashes
ON hashes.furaffinity_id = submission.id ON hashes.furaffinity_id = submission.id
WHERE WHERE
{} file_id = $1
LIMIT 10", LIMIT 10",
filter
);
let matches: Vec<_> = early_return!(
db.query::<str>(&*query, &[val])
.instrument(span!(tracing::Level::TRACE, "waiting for db"))
.await
) )
.into_iter() .bind(id)
} else if let Some(ref name) = opts.name {
sqlx::query(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
artist.name,
hashes.id hash_id
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
JOIN hashes
ON hashes.furaffinity_id = submission.id
WHERE
lower(filename) = lower($1)
LIMIT 10",
)
.bind(name)
} else if let Some(ref url) = opts.url {
sqlx::query(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
artist.name,
hashes.id hash_id
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
JOIN hashes
ON hashes.furaffinity_id = submission.id
WHERE
lower(url) = lower($1)
LIMIT 10",
)
.bind(url)
} else {
return Ok(Box::new(Error::InvalidData));
};
let matches: Result<Vec<File>, _> = query
.map(|row| File { .map(|row| File {
id: row.get("hash_id"), id: row.get("hash_id"),
site_id: row.get::<&str, i32>("id") as i64, site_id: row.get::<i32, _>("id") as i64,
site_id_str: row.get::<&str, i32>("id").to_string(), site_id_str: row.get::<i32, _>("id").to_string(),
url: row.get("url"), url: row.get("url"),
filename: row.get("filename"), filename: row.get("filename"),
artists: row artists: row
.get::<&str, Option<String>>("name") .get::<Option<String>, _>("name")
.map(|artist| vec![artist]), .map(|artist| vec![artist]),
distance: None, distance: None,
hash: None, hash: None,
@ -337,7 +353,10 @@ pub async fn search_file(
})), })),
searched_hash: None, searched_hash: None,
}) })
.collect(); .fetch_all(&db)
.await;
let matches = early_return!(matches);
let resp = warp::http::Response::builder() let resp = warp::http::Response::builder()
.header("x-rate-limit-total-file", file_remaining.1.to_string()) .header("x-rate-limit-total-file", file_remaining.1.to_string())
@ -350,17 +369,13 @@ pub async fn search_file(
} }
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> { pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(db.get().await);
let exists = if let Some(handle) = opts.twitter { let exists = if let Some(handle) = opts.twitter {
!early_return!( let result = sqlx::query_scalar!("SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", handle)
db.query( .fetch_optional(&db)
"SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)",
&[&handle],
)
.await .await
) .map(|row| row.flatten().unwrap_or(false));
.is_empty()
early_return!(result)
} else { } else {
false false
}; };
@ -370,7 +385,7 @@ pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>,
pub async fn search_image_by_url( pub async fn search_image_by_url(
opts: UrlSearchOpts, opts: UrlSearchOpts,
pool: Pool, db: Pool,
tree: Tree, tree: Tree,
api_key: String, api_key: String,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
@ -378,8 +393,6 @@ pub async fn search_image_by_url(
let url = opts.url; let url = opts.url;
let db = early_return!(pool.get().await);
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image"); let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash"); let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
@ -424,13 +437,7 @@ pub async fn search_image_by_url(
let hash: [u8; 8] = hash.as_bytes().try_into().unwrap(); let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
let num = i64::from_be_bytes(hash); let num = i64::from_be_bytes(hash);
let results = image_query( let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec()))
pool.clone(),
tree.clone(),
vec![num],
3,
Some(hash.to_vec()),
)
.await .await
.unwrap(); .unwrap();

View File

@ -1,6 +1,5 @@
#![recursion_limit = "256"] #![recursion_limit = "256"]
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use warp::Filter; use warp::Filter;
@ -12,7 +11,7 @@ mod types;
mod utils; mod utils;
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>; type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
type Pool = bb8::Pool<bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>; type Pool = sqlx::PgPool;
#[derive(Debug)] #[derive(Debug)]
pub struct Node { pub struct Node {
@ -38,93 +37,15 @@ impl bk_tree::Metric<Node> for Hamming {
async fn main() { async fn main() {
configure_tracing(); configure_tracing();
let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN"); let s = std::env::var("DATABASE_URL").expect("Missing DATABASE_URL");
let manager = bb8_postgres::PostgresConnectionManager::new( let db_pool = sqlx::PgPool::connect(&s)
tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"),
tokio_postgres::NoTls,
);
let db_pool = bb8::Pool::builder()
.build(manager)
.await .await
.expect("Unable to build Postgres pool"); .expect("Unable to create Postgres pool");
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming))); let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
let mut max_id = 0; load_updates(db_pool.clone(), tree.clone()).await;
let conn = db_pool.get().await.unwrap();
let mut lock = tree.write().await;
conn.query("SELECT id, hash FROM hashes", &[])
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
if id > max_id {
max_id = id;
}
lock.add(Node { id, hash: bytes });
});
drop(lock);
drop(conn);
let tree_clone = tree.clone();
let pool_clone = db_pool.clone();
tokio::spawn(async move {
use futures::StreamExt;
let max_id = std::sync::atomic::AtomicI32::new(max_id);
let tree = tree_clone;
let pool = pool_clone;
let order = std::sync::atomic::Ordering::SeqCst;
let interval = async_stream::stream! {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
while let item = interval.tick().await {
yield item;
}
};
interval
.for_each(|_| async {
tracing::debug!("Refreshing hashes");
let conn = pool.get().await.unwrap();
let mut lock = tree.write().await;
let id = max_id.load(order);
let mut count: usize = 0;
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
if id > max_id.load(order) {
max_id.store(id, order);
}
lock.add(Node { id, hash: bytes });
count += 1;
});
tracing::trace!("Added {} new hashes", count);
})
.await;
});
let log = warp::log("fuzzysearch"); let log = warp::log("fuzzysearch");
let cors = warp::cors() let cors = warp::cors()
@ -198,6 +119,69 @@ fn configure_tracing() {
registry.init(); registry.init();
} }
#[derive(serde::Deserialize)]
struct HashRow {
id: i32,
hash: i64,
}
async fn create_tree(conn: &Pool) -> bk_tree::BKTree<Node, Hamming> {
use futures::TryStreamExt;
let mut tree = bk_tree::BKTree::new(Hamming);
let mut rows = sqlx::query_as!(HashRow, "SELECT id, hash FROM hashes").fetch(conn);
while let Some(row) = rows.try_next().await.expect("Unable to get row") {
tree.add(Node {
id: row.id,
hash: row.hash.to_be_bytes(),
})
}
tree
}
async fn load_updates(conn: Pool, tree: Tree) {
let mut listener = sqlx::postgres::PgListener::connect_with(&conn)
.await
.unwrap();
listener.listen("fuzzysearch_hash_added").await.unwrap();
let new_tree = create_tree(&conn).await;
let mut lock = tree.write().await;
*lock = new_tree;
drop(lock);
tokio::spawn(async move {
loop {
while let Some(notification) = listener
.try_recv()
.await
.expect("Unable to recv notification")
{
let payload: HashRow = serde_json::from_str(notification.payload()).unwrap();
tracing::debug!(id = payload.id, "Adding new hash to tree");
let mut lock = tree.write().await;
lock.add(Node {
id: payload.id,
hash: payload.hash.to_be_bytes(),
});
drop(lock);
}
tracing::error!("Lost connection to Postgres, recreating tree");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
let new_tree = create_tree(&conn).await;
let mut lock = tree.write().await;
*lock = new_tree;
drop(lock);
tracing::info!("Replaced tree");
}
});
}
fn get_hasher() -> img_hash::Hasher<[u8; 8]> { fn get_hasher() -> img_hash::Hasher<[u8; 8]> {
use img_hash::{HashAlg::Gradient, HasherConfig}; use img_hash::{HashAlg::Gradient, HasherConfig};

View File

@ -1,44 +1,31 @@
use crate::types::*; use crate::types::*;
use crate::utils::extract_rows;
use crate::{Pool, Tree}; use crate::{Pool, Tree};
use tracing_futures::Instrument; use tracing_futures::Instrument;
pub type Db<'a> =
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub async fn lookup_api_key(key: &str, db: Db<'_>) -> Option<ApiKey> { pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option<ApiKey> {
let rows = db sqlx::query_as!(
.query( ApiKey,
"SELECT "SELECT
api_key.id, api_key.id,
api_key.name_limit, api_key.name_limit,
api_key.image_limit, api_key.image_limit,
api_key.hash_limit, api_key.hash_limit,
api_key.name, api_key.name,
account.email account.email owner_email
FROM FROM
api_key api_key
JOIN account JOIN account
ON account.id = api_key.user_id ON account.id = api_key.user_id
WHERE WHERE
api_key.key = $1", api_key.key = $1
&[&key], ",
key
) )
.fetch_optional(db)
.await .await
.expect("Unable to query API keys"); .ok()
.flatten()
match rows.into_iter().next() {
Some(row) => Some(ApiKey {
id: row.get(0),
name_limit: row.get(1),
image_limit: row.get(2),
hash_limit: row.get(3),
name: row.get(4),
owner_email: row.get(5),
}),
_ => None,
}
} }
#[tracing::instrument(skip(pool, tree))] #[tracing::instrument(skip(pool, tree))]
@ -48,7 +35,7 @@ pub async fn image_query(
hashes: Vec<i64>, hashes: Vec<i64>,
distance: i64, distance: i64,
hash: Option<Vec<u8>>, hash: Option<Vec<u8>>,
) -> Result<Vec<File>, tokio_postgres::Error> { ) -> Result<Vec<File>, sqlx::Error> {
let mut results = image_query_sync(pool, tree, hashes, distance, hash); let mut results = image_query_sync(pool, tree, hashes, distance, hash);
let mut matches = Vec::new(); let mut matches = Vec::new();
@ -66,19 +53,26 @@ pub fn image_query_sync(
hashes: Vec<i64>, hashes: Vec<i64>,
distance: i64, distance: i64,
hash: Option<Vec<u8>>, hash: Option<Vec<u8>>,
) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, tokio_postgres::Error>> { ) -> tokio::sync::mpsc::Receiver<Result<Vec<File>, sqlx::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel(50); let (tx, rx) = tokio::sync::mpsc::channel(50);
tokio::spawn(async move { tokio::spawn(async move {
let db = pool.get().await.unwrap(); let db = pool;
for query_hash in hashes { for query_hash in hashes {
let mut seen = std::collections::HashSet::new();
let node = crate::Node::query(query_hash.to_be_bytes()); let node = crate::Node::query(query_hash.to_be_bytes());
let lock = tree.read().await; let lock = tree.read().await;
let items = lock.find(&node, distance as u64); let items = lock.find(&node, distance as u64);
for (_dist, item) in items { for (dist, item) in items {
let query = db.query("SELECT if seen.contains(&item.id) {
continue;
}
seen.insert(item.id);
let row = sqlx::query!("SELECT
hashes.id, hashes.id,
hashes.hash, hashes.hash,
hashes.furaffinity_id, hashes.furaffinity_id,
@ -131,14 +125,44 @@ pub fn image_query_sync(
tweet_media.hash <@ (hashes.hash, 0) tweet_media.hash <@ (hashes.hash, 0)
LIMIT 1 LIMIT 1
) tm ON hashes.twitter_id IS NOT NULL ) tm ON hashes.twitter_id IS NOT NULL
WHERE hashes.id = $1", &[&item.id]).await; WHERE hashes.id = $1", item.id).map(|row| {
let rows = query.map(|rows| { let (site_id, site_info) = if let Some(fa_id) = row.furaffinity_id {
extract_rows(rows, hash.as_deref()).into_iter().map(|mut file| { (
file.searched_hash = Some(query_hash); fa_id as i64,
file Some(SiteInfo::FurAffinity(FurAffinityFile {
}).collect() file_id: row.file_id.unwrap(),
}); }))
tx.send(rows).await.unwrap(); )
} else if let Some(e621_id) = row.e621_id {
(
e621_id as i64,
Some(SiteInfo::E621(E621File {
sources: row.sources,
}))
)
} else if let Some(twitter_id) = row.twitter_id {
(twitter_id, Some(SiteInfo::Twitter))
} else {
(-1, None)
};
let file = File {
id: row.id,
site_id,
site_info,
site_id_str: site_id.to_string(),
url: row.url.unwrap_or_default(),
hash: Some(row.hash),
distance: Some(dist),
artists: row.artists,
filename: row.filename.unwrap_or_default(),
searched_hash: Some(query_hash),
};
vec![file]
}).fetch_one(&db).await;
tx.send(row).await.unwrap();
} }
} }
}.in_current_span()); }.in_current_span());

View File

@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
pub struct ApiKey { pub struct ApiKey {
pub id: i32, pub id: i32,
pub name: Option<String>, pub name: Option<String>,
pub owner_email: Option<String>, pub owner_email: String,
pub name_limit: i16, pub name_limit: i16,
pub image_limit: i16, pub image_limit: i16,
pub hash_limit: i16, pub hash_limit: i16,

View File

@ -1,4 +1,3 @@
use crate::models::Db;
use crate::types::*; use crate::types::*;
#[macro_export] #[macro_export]
@ -51,18 +50,17 @@ macro_rules! early_return {
/// joined requests. /// joined requests.
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub async fn update_rate_limit( pub async fn update_rate_limit(
db: Db<'_>, db: &sqlx::PgPool,
key_id: i32, key_id: i32,
key_group_limit: i16, key_group_limit: i16,
group_name: &'static str, group_name: &'static str,
incr_by: i16, incr_by: i16,
) -> Result<RateLimit, tokio_postgres::Error> { ) -> Result<RateLimit, sqlx::Error> {
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let timestamp = now.timestamp(); let timestamp = now.timestamp();
let time_window = timestamp - (timestamp % 60); let time_window = timestamp - (timestamp % 60);
let rows = db let count: i16 = sqlx::query_scalar!(
.query(
"INSERT INTO "INSERT INTO
rate_limit (api_key_id, time_window, group_name, count) rate_limit (api_key_id, time_window, group_name, count)
VALUES VALUES
@ -70,12 +68,14 @@ pub async fn update_rate_limit(
ON CONFLICT ON CONSTRAINT unique_window ON CONFLICT ON CONSTRAINT unique_window
DO UPDATE set count = rate_limit.count + $4 DO UPDATE set count = rate_limit.count + $4
RETURNING rate_limit.count", RETURNING rate_limit.count",
&[&key_id, &time_window, &group_name, &incr_by], key_id,
time_window,
group_name,
incr_by
) )
.fetch_one(db)
.await?; .await?;
let count: i16 = rows[0].get(0);
if count > key_group_limit { if count > key_group_limit {
Ok(RateLimit::Limited) Ok(RateLimit::Limited)
} else { } else {
@ -85,54 +85,3 @@ pub async fn update_rate_limit(
))) )))
} }
} }
pub fn extract_rows(
rows: Vec<tokio_postgres::Row>,
hash: Option<&[u8]>,
) -> impl IntoIterator<Item = File> + '_ {
rows.into_iter().map(move |row| {
let dbhash: i64 = row.get("hash");
let dbbytes = dbhash.to_be_bytes();
let (furaffinity_id, e621_id, twitter_id): (Option<i32>, Option<i32>, Option<i64>) = (
row.get("furaffinity_id"),
row.get("e621_id"),
row.get("twitter_id"),
);
let (site_id, site_info) = if let Some(fa_id) = furaffinity_id {
(
fa_id as i64,
Some(SiteInfo::FurAffinity(FurAffinityFile {
file_id: row.get("file_id"),
})),
)
} else if let Some(e6_id) = e621_id {
(
e6_id as i64,
Some(SiteInfo::E621(E621File {
sources: row.get("sources"),
})),
)
} else if let Some(t_id) = twitter_id {
(t_id, Some(SiteInfo::Twitter))
} else {
(-1, None)
};
File {
id: row.get("id"),
site_id,
site_info,
site_id_str: site_id.to_string(),
url: row.get("url"),
hash: Some(dbhash),
distance: hash
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
.flatten(),
artists: row.get("artists"),
filename: row.get("filename"),
searched_hash: None,
}
})
}