diff --git a/.drone.yml b/.drone.yml index 76b7213..b9fee86 100644 --- a/.drone.yml +++ b/.drone.yml @@ -33,4 +33,8 @@ steps: from_secret: sccache_s3_endpoint SCCACHE_S3_USE_SSL: true +--- +kind: signature +hmac: 665dab5e07086669c4b215ed86faa0e1e63c495b0bf020099fb1edd33757618b + ... diff --git a/Cargo.lock b/Cargo.lock index 95b7a1e..94a596e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,27 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "async-stream" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3670df70cbc01729f901f94c887814b3c68db038aad1329a418bae178bc5295c" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3548b8efc9f8e8a5a0a2808c5bd8451a9031b9e5b879a79590304ae928b0a70" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.42" @@ -105,30 +126,6 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" -[[package]] -name = "bb8" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374bba43fc924d90393ee7768e6f75d223a98307a488fe5bc34b66c3e96932a6" -dependencies = [ - "async-trait", - "futures", - "tokio 0.2.25", -] - -[[package]] -name = "bb8-postgres" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39a233af6ea3952e20d01863c87b4f6689b2f806249688b0908b5f02d4fa41ac" -dependencies = [ - "async-trait", - "bb8", - "futures", - "tokio 0.2.25", - "tokio-postgres 0.5.5", -] - [[package]] name = "bindgen" version = "0.54.0" @@ -835,25 +832,27 @@ dependencies = [ name = "fuzzysearch" version = "0.1.0" dependencies = [ - "anyhow", - "bb8", - "bb8-postgres", + "async-stream", "bk-tree", - "bytes 0.5.6", + "bytes 1.0.1", "chrono", "ffmpeg-next", "futures", - "futures-util", "fuzzysearch-common", "hamming", + "hyper 0.14.4", "image", "img_hash", "infer", + "lazy_static", "opentelemetry", "opentelemetry-jaeger", + "prometheus 0.11.0", + "reqwest 0.11.1", "serde", - "tokio 0.2.25", - "tokio-postgres 0.5.5", + "serde_json", + "sqlx 0.5.1", + "tokio 1.2.0", "tracing", "tracing-futures", "tracing-opentelemetry", @@ -1338,11 +1337,11 @@ checksum = "8906512588cd815b8f759fd0ac11de2a84c985c0f792f70df611e9325c270c1f" [[package]] name = "input_buffer" -version = "0.3.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19a8a95243d5a0398cae618ec29477c6e3cb631152be5c19481f80bc71559754" +checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" dependencies = [ - "bytes 0.5.6", + "bytes 1.0.1", ] [[package]] @@ -2530,7 +2529,7 @@ dependencies = [ "pin-project-lite 0.2.4", "serde", "serde_json", - "serde_urlencoded 0.7.0", + "serde_urlencoded", "tokio 0.2.25", "tokio-tls", "url", @@ -2565,7 +2564,7 @@ dependencies = [ "pin-project-lite 0.2.4", "serde", "serde_json", - "serde_urlencoded 0.7.0", + "serde_urlencoded", "tokio 1.2.0", "tokio-native-tls", "url", @@ -2824,18 +2823,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_urlencoded" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9ec5d77e2d4c73717816afac02670d5c4f534ea95ed430442cad02e7a6e32c97" -dependencies = [ - "dtoa", - "itoa", - "serde", - "url", -] - [[package]] name = "serde_urlencoded" version = "0.7.0" @@ -3516,14 +3503,14 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.11.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d9e878ad426ca286e4dcae09cbd4e1973a7f8987d97570e2469703dd7f5720c" +checksum = "e1a5f475f1b9d077ea1017ecbc60890fda8e54942d680ca0b1d2b47cfa2d861b" dependencies = [ "futures-util", "log", - "pin-project 0.4.27", - "tokio 0.2.25", + "pin-project 1.0.5", + "tokio 1.2.0", "tungstenite", ] @@ -3699,18 +3686,18 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" [[package]] name = "tungstenite" -version = "0.11.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0308d80d86700c5878b9ef6321f020f29b1bb9d5ff3cab25e75e23f3a492a23" +checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24" dependencies = [ - "base64 0.12.3", + "base64 0.13.0", "byteorder", - "bytes 0.5.6", + "bytes 1.0.1", "http", "httparse", "input_buffer", "log", - "rand 0.7.3", + "rand 0.8.3", "sha-1 0.9.4", "url", "utf-8", @@ -3812,12 +3799,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "urlencoding" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9232eb53352b4442e40d7900465dfc534e8cb2dc8f18656fcb2ac16112b5593" - [[package]] name = "utf-8" version = "0.7.5" @@ -3848,30 +3829,32 @@ dependencies = [ [[package]] name = "warp" -version = "0.2.5" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f41be6df54c97904af01aa23e613d4521eed7ab23537cede692d4058f6449407" +checksum = "3dafd0aac2818a94a34df0df1100a7356c493d8ede4393875fd0b5c51bb6bc80" dependencies = [ - "bytes 0.5.6", + "bytes 1.0.1", "futures", "headers", "http", - "hyper 0.13.10", + "hyper 0.14.4", "log", "mime", "mime_guess", "multipart", - "pin-project 0.4.27", + "percent-encoding", + "pin-project 1.0.5", "scoped-tls", "serde", "serde_json", - "serde_urlencoded 0.6.1", - "tokio 0.2.25", + "serde_urlencoded", + "tokio 1.2.0", + "tokio-stream", "tokio-tungstenite", + "tokio-util 0.6.3", "tower-service", "tracing", "tracing-futures", - "urlencoding", ] [[package]] diff --git a/fuzzysearch-common/src/types.rs b/fuzzysearch-common/src/types.rs index ac66095..50a5a03 100644 --- a/fuzzysearch-common/src/types.rs +++ b/fuzzysearch-common/src/types.rs @@ -1,5 +1,28 @@ use serde::Serialize; +#[derive(Debug, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum Rating { + General, + Mature, + Adult, +} + +impl std::str::FromStr for Rating { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + let rating = match s { + "g" | "s" | "general" => Self::General, + "m" | "q" | "mature" => Self::Mature, + "a" | "e" | "adult" => Self::Adult, + _ => return Err("unknown rating"), + }; + + Ok(rating) + } +} + /// A general type for every result in a search. 
#[derive(Debug, Default, Serialize)] pub struct SearchResult { @@ -11,6 +34,7 @@ pub struct SearchResult { pub url: String, pub filename: String, pub artists: Option<Vec<String>>, + pub rating: Option<Rating>, #[serde(skip_serializing_if = "Option::is_none")] #[serde(flatten)] diff --git a/fuzzysearch-ingest-furaffinity/src/main.rs b/fuzzysearch-ingest-furaffinity/src/main.rs index 48eb1c3..eb261c0 100644 --- a/fuzzysearch-ingest-furaffinity/src/main.rs +++ b/fuzzysearch-ingest-furaffinity/src/main.rs @@ -165,7 +165,7 @@ async fn main() { tokio::spawn(async move { if let Err(e) = connection.await { - panic!(e); + panic!("postgres connection error: {:?}", e); } }); diff --git a/fuzzysearch/Cargo.toml b/fuzzysearch/Cargo.toml index 9ccdba9..1b8d0f8 100644 --- a/fuzzysearch/Cargo.toml +++ b/fuzzysearch/Cargo.toml @@ -9,29 +9,33 @@ tracing = "0.1" tracing-subscriber = "0.2" tracing-futures = "0.2" +prometheus = { version = "0.11", features = ["process"] } +lazy_static = "1" + opentelemetry = "0.6" opentelemetry-jaeger = "0.5" tracing-opentelemetry = "0.5" -tokio = { version = "0.2", features = ["full"] } -futures = "0.3" -futures-util = "0.3" +tokio = { version = "1", features = ["full"] } +async-stream = "0.3" + +futures = "0.3" -anyhow = "1" chrono = "0.4" -bytes = "0.5" -infer = { version = "0.3", default-features = false } +bytes = "1" serde = { version = "1", features = ["derive"] } -warp = "0.2" +serde_json = "1" -tokio-postgres = "0.5" -bb8 = "0.4" -bb8-postgres = "0.4" +warp = "0.3" +reqwest = "0.11" +hyper = "0.14" -image = "0.23" +sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "postgres", "macros", "json", "offline"] } + +infer = { version = "0.3", default-features = false } ffmpeg-next = "4" - +image = "0.23" img_hash = "3" hamming = "0.1" diff --git a/fuzzysearch/Dockerfile b/fuzzysearch/Dockerfile index 186616d..253ff7e 100644 --- a/fuzzysearch/Dockerfile +++ b/fuzzysearch/Dockerfile @@ -1,10 +1,14 @@ FROM rust:1-slim AS builder WORKDIR /src +ENV SQLX_OFFLINE=true +RUN apt-get update -y && apt-get install -y libssl-dev pkg-config COPY . . RUN cargo install --root / --path .
FROM debian:buster-slim -EXPOSE 8080 +EXPOSE 8080 8081 +ENV METRICS_HOST=0.0.0.0:8081 WORKDIR /app +RUN apt-get update -y && apt-get install -y openssl ca-certificates && rm -rf /var/lib/apt/lists/* COPY --from=builder /bin/fuzzysearch /bin/fuzzysearch CMD ["/bin/fuzzysearch"] diff --git a/fuzzysearch/sqlx-data.json b/fuzzysearch/sqlx-data.json new file mode 100644 index 0000000..0b596c7 --- /dev/null +++ b/fuzzysearch/sqlx-data.json @@ -0,0 +1,200 @@ +{ + "db": "PostgreSQL", + "1984ce60f052d6a29638f8e05b35671b8edfbf273783d4b843ebd35cbb8a391f": { + "query": "INSERT INTO\n rate_limit (api_key_id, time_window, group_name, count)\n VALUES\n ($1, $2, $3, $4)\n ON CONFLICT ON CONSTRAINT unique_window\n DO UPDATE set count = rate_limit.count + $4\n RETURNING rate_limit.count", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "count", + "type_info": "Int2" + } + ], + "parameters": { + "Left": [ + "Int4", + "Int8", + "Text", + "Int2" + ] + }, + "nullable": [ + false + ] + } + }, + "1bd0057782de5a3b41f90081a31d24d14bb70299391050c3404742a6d2915d9e": { + "query": "SELECT\n hashes.id,\n hashes.hash,\n hashes.furaffinity_id,\n hashes.e621_id,\n hashes.twitter_id,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.url)\n WHEN e621_id IS NOT NULL THEN (e.data->'file'->>'url')\n WHEN twitter_id IS NOT NULL THEN (tm.url)\n END url,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.filename)\n WHEN e621_id IS NOT NULL THEN ((e.data->'file'->>'md5') || '.' || (e.data->'file'->>'ext'))\n WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))\n END filename,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'tags'->'artist'))\n WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')\n END artists,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.file_id)\n END file_id,\n CASE\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))\n END sources,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.rating)\n WHEN e621_id IS NOT NULL THEN (e.data->>'rating')\n WHEN twitter_id IS NOT NULL THEN\n CASE\n WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult'\n WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general'\n END\n END rating\n FROM\n hashes\n LEFT JOIN LATERAL (\n SELECT *\n FROM submission\n JOIN artist ON submission.artist_id = artist.id\n WHERE submission.id = hashes.furaffinity_id\n ) f ON hashes.furaffinity_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM e621\n WHERE e621.id = hashes.e621_id\n ) e ON hashes.e621_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet\n WHERE tweet.id = hashes.twitter_id\n ) tw ON hashes.twitter_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet_media\n WHERE\n tweet_media.tweet_id = hashes.twitter_id AND\n tweet_media.hash <@ (hashes.hash, 0)\n LIMIT 1\n ) tm ON hashes.twitter_id IS NOT NULL\n WHERE hashes.id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "hash", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "furaffinity_id", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "e621_id", + "type_info": "Int4" + }, + { + "ordinal": 4, + "name": "twitter_id", + "type_info": "Int8" + }, + { + "ordinal": 5, + "name": "url", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "filename", + "type_info": "Text" + }, + { + "ordinal": 
7, + "name": "artists", + "type_info": "TextArray" + }, + { + "ordinal": 8, + "name": "file_id", + "type_info": "Int4" + }, + { + "ordinal": 9, + "name": "sources", + "type_info": "TextArray" + }, + { + "ordinal": 10, + "name": "rating", + "type_info": "Bpchar" + } + ], + "parameters": { + "Left": [ + "Int4" + ] + }, + "nullable": [ + false, + false, + true, + true, + true, + null, + null, + null, + null, + null, + null + ] + } + }, + "659ee9ddc1c5ccd42ba9dc1617440544c30ece449ba3ba7f9d39f447b8af3cfe": { + "query": "SELECT\n api_key.id,\n api_key.name_limit,\n api_key.image_limit,\n api_key.hash_limit,\n api_key.name,\n account.email owner_email\n FROM\n api_key\n JOIN account\n ON account.id = api_key.user_id\n WHERE\n api_key.key = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "name_limit", + "type_info": "Int2" + }, + { + "ordinal": 2, + "name": "image_limit", + "type_info": "Int2" + }, + { + "ordinal": 3, + "name": "hash_limit", + "type_info": "Int2" + }, + { + "ordinal": 4, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 5, + "name": "owner_email", + "type_info": "Varchar" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + true, + false + ] + } + }, + "6b8d304fc40fa539ae671e6e24e7978ad271cb7a1cafb20fc4b4096a958d790f": { + "query": "SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + null + ] + } + }, + "fe60be66b2d8a8f02b3bfe06d1f0e57e4bb07e80cba1b379a5f17f6cbd8b075c": { + "query": "SELECT id, hash FROM hashes", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "hash", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false + ] + } + } +} \ No newline at end of file diff --git a/fuzzysearch/src/filters.rs b/fuzzysearch/src/filters.rs index bc9994d..d1974b4 100644 --- a/fuzzysearch/src/filters.rs +++ b/fuzzysearch/src/filters.rs @@ -10,10 +10,11 @@ pub fn search( ) -> impl Filter + Clone { search_image(db.clone(), tree.clone()) .or(search_hashes(db.clone(), tree.clone())) - .or(stream_search_image(db.clone(), tree)) + .or(stream_search_image(db.clone(), tree.clone())) .or(search_file(db.clone())) .or(search_video(db.clone())) - .or(check_handle(db)) + .or(check_handle(db.clone())) + .or(search_image_by_url(db, tree)) } pub fn search_file(db: Pool) -> impl Filter + Clone { @@ -55,6 +56,19 @@ pub fn search_image( }) } +pub fn search_image_by_url( + db: Pool, + tree: Tree, +) -> impl Filter + Clone { + warp::path("url") + .and(warp::get()) + .and(warp::query::()) + .and(with_pool(db)) + .and(with_tree(tree)) + .and(with_api_key()) + .and_then(handlers::search_image_by_url) +} + pub fn search_hashes( db: Pool, tree: Tree, diff --git a/fuzzysearch/src/handlers.rs b/fuzzysearch/src/handlers.rs index 729fe0d..7fd2bb6 100644 --- a/fuzzysearch/src/handlers.rs +++ b/fuzzysearch/src/handlers.rs @@ -1,46 +1,92 @@ -use crate::models::{image_query, image_query_sync}; -use crate::types::*; -use crate::{rate_limit, Pool, Tree}; +use lazy_static::lazy_static; +use prometheus::{register_histogram, register_int_counter, Histogram, IntCounter}; +use std::convert::TryInto; use tracing::{span, warn}; use 
tracing_futures::Instrument; -use warp::{reject, Rejection, Reply}; +use warp::{Rejection, Reply}; +use crate::models::{image_query, image_query_sync}; +use crate::types::*; +use crate::{early_return, rate_limit, Pool, Tree}; use fuzzysearch_common::types::{SearchResult, SiteInfo}; -fn map_bb8_err(err: bb8::RunError) -> Rejection { - reject::custom(Error::from(err)) -} - -fn map_postgres_err(err: tokio_postgres::Error) -> Rejection { - reject::custom(Error::from(err)) +lazy_static! { + static ref IMAGE_HASH_DURATION: Histogram = register_histogram!( + "fuzzysearch_api_image_hash_seconds", + "Duration to perform an image hash operation" + ) + .unwrap(); + static ref IMAGE_URL_DOWNLOAD_DURATION: Histogram = register_histogram!( + "fuzzysearch_api_image_url_download_seconds", + "Duration to download an image from a provided URL" + ) + .unwrap(); + static ref UNHANDLED_REJECTIONS: IntCounter = register_int_counter!( + "fuzzysearch_api_unhandled_rejections_count", + "Number of unhandled HTTP rejections" + ) + .unwrap(); } #[derive(Debug)] enum Error { - BB8(bb8::RunError), - Postgres(tokio_postgres::Error), + Postgres(sqlx::Error), + Reqwest(reqwest::Error), InvalidData, + InvalidImage, ApiKey, RateLimit, } -impl From> for Error { - fn from(err: bb8::RunError) -> Self { - Error::BB8(err) +impl warp::Reply for Error { + fn into_response(self) -> warp::reply::Response { + let msg = match self { + Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage { + code: 500, + message: "Internal server error".to_string(), + }, + Error::InvalidData => ErrorMessage { + code: 400, + message: "Invalid data provided".to_string(), + }, + Error::InvalidImage => ErrorMessage { + code: 400, + message: "Invalid image provided".to_string(), + }, + Error::ApiKey => ErrorMessage { + code: 401, + message: "Invalid API key".to_string(), + }, + Error::RateLimit => ErrorMessage { + code: 429, + message: "Too many requests".to_string(), + }, + }; + + let body = hyper::body::Body::from(serde_json::to_string(&msg).unwrap()); + + warp::http::Response::builder() + .status(msg.code) + .body(body) + .unwrap() } } -impl From for Error { - fn from(err: tokio_postgres::Error) -> Self { +impl From for Error { + fn from(err: sqlx::Error) -> Self { Error::Postgres(err) } } -impl warp::reject::Reject for Error {} +impl From for Error { + fn from(err: reqwest::Error) -> Self { + Error::Reqwest(err) + } +} async fn get_field_bytes(form: warp::multipart::FormData, field: &str) -> bytes::BytesMut { use bytes::BufMut; - use futures_util::StreamExt; + use futures::StreamExt; let parts: Vec<_> = form.collect().await; let mut parts = parts @@ -66,6 +112,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas let len = bytes.len(); + let _timer = IMAGE_HASH_DURATION.start_timer(); let hash = tokio::task::spawn_blocking(move || { let hasher = fuzzysearch_common::get_hasher(); let image = image::load_from_memory(&bytes).unwrap(); @@ -74,6 +121,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas .instrument(span!(tracing::Level::TRACE, "hashing image", len)) .await .unwrap(); + drop(_timer); let mut buf: [u8; 8] = [0; 8]; buf.copy_from_slice(&hash.as_bytes()); @@ -83,7 +131,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas #[tracing::instrument(skip(form))] async fn hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> { - use bytes::buf::BufExt; + use bytes::Buf; let bytes = get_field_bytes(form, "video").await; @@ -106,21 +154,19 @@ async fn 
hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> { pub async fn search_image( form: warp::multipart::FormData, opts: ImageSearchOpts, - pool: Pool, + db: Pool, tree: Tree, api_key: String, -) -> Result { - let db = pool.get().await.map_err(map_bb8_err)?; - - rate_limit!(&api_key, &db, image_limit, "image"); - rate_limit!(&api_key, &db, hash_limit, "hash"); +) -> Result, Rejection> { + let image_remaining = rate_limit!(&api_key, &db, image_limit, "image"); + let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash"); let (num, hash) = hash_input(form).await; let mut items = { if opts.search_type == Some(ImageSearchType::Force) { image_query( - pool.clone(), + db.clone(), tree.clone(), vec![num], 10, @@ -130,7 +176,7 @@ pub async fn search_image( .unwrap() } else { let results = image_query( - pool.clone(), + db.clone(), tree.clone(), vec![num], 0, @@ -140,7 +186,7 @@ pub async fn search_image( .unwrap(); if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) { image_query( - pool.clone(), + db.clone(), tree.clone(), vec![num], 10, @@ -166,7 +212,20 @@ pub async fn search_image( matches: items, }; - Ok(warp::reply::json(&similarity)) + let resp = warp::http::Response::builder() + .header("x-image-hash", num.to_string()) + .header("x-rate-limit-total-image", image_remaining.1.to_string()) + .header( + "x-rate-limit-remaining-image", + image_remaining.0.to_string(), + ) + .header("x-rate-limit-total-hash", hash_remaining.1.to_string()) + .header("x-rate-limit-remaining-hash", hash_remaining.0.to_string()) + .header("content-type", "application/json") + .body(serde_json::to_string(&similarity).unwrap()) + .unwrap(); + + Ok(Box::new(resp)) } pub async fn stream_image( @@ -174,34 +233,36 @@ pub async fn stream_image( pool: Pool, tree: Tree, api_key: String, -) -> Result { - use futures_util::StreamExt; - - let db = pool.get().await.map_err(map_bb8_err)?; - - rate_limit!(&api_key, &db, image_limit, "image", 2); - rate_limit!(&api_key, &db, hash_limit, "hash"); +) -> Result, Rejection> { + rate_limit!(&api_key, &pool, image_limit, "image", 2); + rate_limit!(&api_key, &pool, hash_limit, "hash"); let (num, hash) = hash_input(form).await; - let event_stream = image_query_sync( + let mut query = image_query_sync( pool.clone(), tree, vec![num], 10, Some(hash.as_bytes().to_vec()), - ) - .map(sse_matches); + ); - Ok(warp::sse::reply(event_stream)) + let event_stream = async_stream::stream! 
{ + while let Some(result) = query.recv().await { + yield sse_matches(result); + } + }; + + Ok(Box::new(warp::sse::reply(event_stream))) } +#[allow(clippy::unnecessary_wraps)] fn sse_matches( - matches: Result, tokio_postgres::Error>, -) -> Result { + matches: Result, sqlx::Error>, +) -> Result { let items = matches.unwrap(); - Ok(warp::sse::json(items)) + Ok(warp::sse::Event::default().json_data(items).unwrap()) } pub async fn search_hashes( @@ -209,9 +270,8 @@ pub async fn search_hashes( db: Pool, tree: Tree, api_key: String, -) -> Result { +) -> Result, Rejection> { let pool = db.clone(); - let db = db.get().await.map_err(map_bb8_err)?; let hashes: Vec = opts .hashes @@ -221,10 +281,10 @@ pub async fn search_hashes( .collect(); if hashes.is_empty() { - return Err(warp::reject::custom(Error::InvalidData)); + return Ok(Box::new(Error::InvalidData)); } - rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); + let image_remaining = rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16); let mut results = image_query_sync( pool, @@ -236,66 +296,107 @@ pub async fn search_hashes( let mut matches = Vec::new(); while let Some(r) = results.recv().await { - matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?); + matches.extend(early_return!(r)); } - Ok(warp::reply::json(&matches)) + let resp = warp::http::Response::builder() + .header("x-rate-limit-total-image", image_remaining.1.to_string()) + .header( + "x-rate-limit-remaining-image", + image_remaining.0.to_string(), + ) + .header("content-type", "application/json") + .body(serde_json::to_string(&matches).unwrap()) + .unwrap(); + + Ok(Box::new(resp)) } pub async fn search_file( opts: FileSearchOpts, db: Pool, api_key: String, -) -> Result { - let db = db.get().await.map_err(map_bb8_err)?; +) -> Result, Rejection> { + use sqlx::Row; - rate_limit!(&api_key, &db, name_limit, "file"); + let file_remaining = rate_limit!(&api_key, &db, name_limit, "file"); - let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) = - if let Some(ref id) = opts.id { - ("file_id = $1", id) - } else if let Some(ref name) = opts.name { - ("lower(filename) = lower($1)", name) - } else if let Some(ref url) = opts.url { - ("lower(url) = lower($1)", url) - } else { - return Err(warp::reject::custom(Error::InvalidData)); - }; + let query = if let Some(ref id) = opts.id { + sqlx::query( + "SELECT + submission.id, + submission.url, + submission.filename, + submission.file_id, + submission.rating, + artist.name, + hashes.id hash_id + FROM + submission + JOIN artist + ON artist.id = submission.artist_id + JOIN hashes + ON hashes.furaffinity_id = submission.id + WHERE + file_id = $1 + LIMIT 10", + ) + .bind(id) + } else if let Some(ref name) = opts.name { + sqlx::query( + "SELECT + submission.id, + submission.url, + submission.filename, + submission.file_id, + submission.rating, + artist.name, + hashes.id hash_id + FROM + submission + JOIN artist + ON artist.id = submission.artist_id + JOIN hashes + ON hashes.furaffinity_id = submission.id + WHERE + lower(filename) = lower($1) + LIMIT 10", + ) + .bind(name) + } else if let Some(ref url) = opts.url { + sqlx::query( + "SELECT + submission.id, + submission.url, + submission.filename, + submission.file_id, + submission.rating, + artist.name, + hashes.id hash_id + FROM + submission + JOIN artist + ON artist.id = submission.artist_id + JOIN hashes + ON hashes.furaffinity_id = submission.id + WHERE + lower(url) = lower($1) + LIMIT 10", + ) + .bind(url) + } else { + 
return Ok(Box::new(Error::InvalidData)); + }; - let query = format!( - "SELECT - submission.id, - submission.url, - submission.filename, - submission.file_id, - artist.name, - hashes.id hash_id - FROM - submission - JOIN artist - ON artist.id = submission.artist_id - JOIN hashes - ON hashes.furaffinity_id = submission.id - WHERE - {} - LIMIT 10", - filter - ); - - let matches: Vec<_> = db - .query::(&*query, &[val]) - .instrument(span!(tracing::Level::TRACE, "waiting for db")) - .await - .map_err(map_postgres_err)? - .into_iter() + let matches: Result, _> = query .map(|row| SearchResult { id: row.get("hash_id"), - site_id: row.get::<&str, i32>("id") as i64, - site_id_str: row.get::<&str, i32>("id").to_string(), + site_id: row.get::("id") as i64, + site_id_str: row.get::("id").to_string(), url: row.get("url"), filename: row.get("filename"), artists: row - .get::<&str, Option>("name") + .get::, _>("name") .map(|artist| vec![artist]), distance: None, hash: None, @@ -303,10 +404,23 @@ pub async fn search_file( file_id: row.get("file_id"), }), searched_hash: None, + rating: row + .get::, _>("rating") + .and_then(|rating| rating.parse().ok()), }) - .collect(); + .fetch_all(&db) + .await; - Ok(warp::reply::json(&matches)) + let matches = early_return!(matches); + + let resp = warp::http::Response::builder() + .header("x-rate-limit-total-file", file_remaining.1.to_string()) + .header("x-rate-limit-remaining-file", file_remaining.0.to_string()) + .header("content-type", "application/json") + .body(serde_json::to_string(&matches).unwrap()) + .unwrap(); + + Ok(Box::new(resp)) } pub async fn search_video( @@ -319,56 +433,116 @@ pub async fn search_video( Ok(warp::reply::json(&hashes)) } -pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result { - let db = db.get().await.map_err(map_bb8_err)?; - +pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result, Rejection> { let exists = if let Some(handle) = opts.twitter { - !db.query( - "SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)", - &[&handle], - ) - .await - .map_err(map_postgres_err)? 
- .is_empty() + let result = sqlx::query_scalar!("SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", handle) + .fetch_optional(&db) + .await + .map(|row| row.flatten().unwrap_or(false)); + + early_return!(result) } else { false }; - Ok(warp::reply::json(&exists)) + Ok(Box::new(warp::reply::json(&exists))) +} + +pub async fn search_image_by_url( + opts: UrlSearchOpts, + db: Pool, + tree: Tree, + api_key: String, +) -> Result, Rejection> { + use bytes::BufMut; + + let url = opts.url; + + let image_remaining = rate_limit!(&api_key, &db, image_limit, "image"); + let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash"); + + let _timer = IMAGE_URL_DOWNLOAD_DURATION.start_timer(); + + let mut resp = match reqwest::get(&url).await { + Ok(resp) => resp, + Err(_err) => return Ok(Box::new(Error::InvalidImage)), + }; + + let content_length = resp + .headers() + .get("content-length") + .and_then(|len| { + String::from_utf8_lossy(len.as_bytes()) + .parse::() + .ok() + }) + .unwrap_or(0); + + if content_length > 10_000_000 { + return Ok(Box::new(Error::InvalidImage)); + } + + let mut buf = bytes::BytesMut::with_capacity(content_length); + + while let Some(chunk) = early_return!(resp.chunk().await) { + if buf.len() + chunk.len() > 10_000_000 { + return Ok(Box::new(Error::InvalidImage)); + } + + buf.put(chunk); + } + + drop(_timer); + + let _timer = IMAGE_HASH_DURATION.start_timer(); + let hash = tokio::task::spawn_blocking(move || { + let hasher = fuzzysearch_common::get_hasher(); + let image = image::load_from_memory(&buf).unwrap(); + hasher.hash_image(&image) + }) + .instrument(span!(tracing::Level::TRACE, "hashing image")) + .await + .unwrap(); + drop(_timer); + + let hash: [u8; 8] = hash.as_bytes().try_into().unwrap(); + let num = i64::from_be_bytes(hash); + + let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec())) + .await + .unwrap(); + + let resp = warp::http::Response::builder() + .header("x-image-hash", num.to_string()) + .header("x-rate-limit-total-image", image_remaining.1.to_string()) + .header( + "x-rate-limit-remaining-image", + image_remaining.0.to_string(), + ) + .header("x-rate-limit-total-hash", hash_remaining.1.to_string()) + .header("x-rate-limit-remaining-hash", hash_remaining.0.to_string()) + .header("content-type", "application/json") + .body(serde_json::to_string(&results).unwrap()) + .unwrap(); + + Ok(Box::new(resp)) } #[tracing::instrument] -pub async fn handle_rejection(err: Rejection) -> Result { +pub async fn handle_rejection(err: Rejection) -> Result, std::convert::Infallible> { warn!("had rejection"); + UNHANDLED_REJECTIONS.inc(); + let (code, message) = if err.is_not_found() { ( warp::http::StatusCode::NOT_FOUND, "This page does not exist", ) - } else if let Some(err) = err.find::() { - match err { - Error::BB8(_inner) => ( - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - "A database error occured", - ), - Error::Postgres(_inner) => ( - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - "A database error occured", - ), - Error::InvalidData => ( - warp::http::StatusCode::BAD_REQUEST, - "Unable to operate on provided data", - ), - Error::ApiKey => ( - warp::http::StatusCode::UNAUTHORIZED, - "Invalid API key provided", - ), - Error::RateLimit => ( - warp::http::StatusCode::TOO_MANY_REQUESTS, - "Your API token is rate limited", - ), - } + } else if err.find::().is_some() { + return Ok(Box::new(Error::InvalidData) as Box); + } else if err.find::().is_some() { + return 
Ok(Box::new(Error::InvalidData) as Box); } else { ( warp::http::StatusCode::INTERNAL_SERVER_ERROR, @@ -381,5 +555,5 @@ pub async fn handle_rejection(err: Rejection) -> Result>>; +type Pool = sqlx::PgPool; + +#[derive(Debug)] +pub struct Node { + id: i32, + hash: [u8; 8], +} + +impl Node { + pub fn query(hash: [u8; 8]) -> Self { + Self { id: -1, hash } + } +} + +pub struct Hamming; + +impl bk_tree::Metric for Hamming { + fn distance(&self, a: &Node, b: &Node) -> u64 { + hamming::distance_fast(&a.hash, &b.hash).unwrap() + } +} + +#[tokio::main] +async fn main() { + configure_tracing(); + + ffmpeg_next::init().expect("Unable to initialize ffmpeg"); + + let s = std::env::var("DATABASE_URL").expect("Missing DATABASE_URL"); + + let db_pool = sqlx::PgPool::connect(&s) + .await + .expect("Unable to create Postgres pool"); + + serve_metrics().await; + + let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming))); + + load_updates(db_pool.clone(), tree.clone()).await; + + let log = warp::log("fuzzysearch"); + let cors = warp::cors() + .allow_any_origin() + .allow_headers(vec!["x-api-key"]) + .allow_methods(vec!["GET", "POST"]); + + let options = warp::options().map(|| "✓"); + + let api = options.or(filters::search(db_pool, tree)); + let routes = api + .or(warp::path::end() + .map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net")))) + .with(log) + .with(cors) + .recover(handlers::handle_rejection); + + warp::serve(routes).run(([0, 0, 0, 0], 8080)).await; +} fn configure_tracing() { use opentelemetry::{ @@ -65,133 +123,102 @@ fn configure_tracing() { registry.init(); } -#[derive(Debug)] -pub struct Node { +async fn metrics( + _: hyper::Request, +) -> Result, std::convert::Infallible> { + use hyper::{Body, Response}; + use prometheus::{Encoder, TextEncoder}; + + let mut buffer = Vec::new(); + let encoder = TextEncoder::new(); + + let metric_families = prometheus::gather(); + encoder.encode(&metric_families, &mut buffer).unwrap(); + + Ok(Response::new(Body::from(buffer))) +} + +async fn serve_metrics() { + use hyper::{ + service::{make_service_fn, service_fn}, + Server, + }; + use std::convert::Infallible; + use std::net::SocketAddr; + + let make_svc = make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(metrics)) }); + + let addr: SocketAddr = std::env::var("METRICS_HOST") + .expect("Missing METRICS_HOST") + .parse() + .expect("Invalid METRICS_HOST"); + + let server = Server::bind(&addr).serve(make_svc); + + tokio::spawn(async move { + server.await.expect("Metrics server error"); + }); +} + +#[derive(serde::Deserialize)] +struct HashRow { id: i32, - hash: [u8; 8], + hash: i64, } -impl Node { - pub fn query(hash: [u8; 8]) -> Self { - Self { id: -1, hash } +async fn create_tree(conn: &Pool) -> bk_tree::BKTree { + use futures::TryStreamExt; + + let mut tree = bk_tree::BKTree::new(Hamming); + + let mut rows = sqlx::query_as!(HashRow, "SELECT id, hash FROM hashes").fetch(conn); + + while let Some(row) = rows.try_next().await.expect("Unable to get row") { + tree.add(Node { + id: row.id, + hash: row.hash.to_be_bytes(), + }) } + + tree } -type Tree = Arc>>; - -pub struct Hamming; - -impl bk_tree::Metric for Hamming { - fn distance(&self, a: &Node, b: &Node) -> u64 { - hamming::distance_fast(&a.hash, &b.hash).unwrap() - } -} - -#[tokio::main] -async fn main() { - ffmpeg_next::init().expect("Unable to initialize ffmpeg"); - - configure_tracing(); - - let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN"); - - let manager = 
bb8_postgres::PostgresConnectionManager::new( - tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"), - tokio_postgres::NoTls, - ); - - let db_pool = bb8::Pool::builder() - .build(manager) +async fn load_updates(conn: Pool, tree: Tree) { + let mut listener = sqlx::postgres::PgListener::connect_with(&conn) .await - .expect("Unable to build Postgres pool"); + .unwrap(); + listener.listen("fuzzysearch_hash_added").await.unwrap(); - let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming))); - - let mut max_id = 0; - - let conn = db_pool.get().await.unwrap(); + let new_tree = create_tree(&conn).await; let mut lock = tree.write().await; - conn.query("SELECT id, hash FROM hashes", &[]) - .await - .unwrap() - .into_iter() - .for_each(|row| { - let id: i32 = row.get(0); - let hash: i64 = row.get(1); - let bytes = hash.to_be_bytes(); + *lock = new_tree; + drop(lock); - if id > max_id { - max_id = id; + tokio::spawn(async move { + loop { + while let Some(notification) = listener + .try_recv() + .await + .expect("Unable to recv notification") + { + let payload: HashRow = serde_json::from_str(notification.payload()).unwrap(); + tracing::debug!(id = payload.id, "Adding new hash to tree"); + + let mut lock = tree.write().await; + lock.add(Node { + id: payload.id, + hash: payload.hash.to_be_bytes(), + }); + drop(lock); } - lock.add(Node { id, hash: bytes }); - }); - drop(lock); - drop(conn); - - let tree_clone = tree.clone(); - let pool_clone = db_pool.clone(); - tokio::spawn(async move { - use futures_util::StreamExt; - - let max_id = std::sync::atomic::AtomicI32::new(max_id); - let tree = tree_clone; - let pool = pool_clone; - - let order = std::sync::atomic::Ordering::SeqCst; - - let interval = tokio::time::interval(std::time::Duration::from_secs(30)); - - interval - .for_each(|_| async { - tracing::debug!("Refreshing hashes"); - - let conn = pool.get().await.unwrap(); - let mut lock = tree.write().await; - let id = max_id.load(order); - - let mut count = 0; - - conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id]) - .await - .unwrap() - .into_iter() - .for_each(|row| { - let id: i32 = row.get(0); - let hash: i64 = row.get(1); - let bytes = hash.to_be_bytes(); - - if id > max_id.load(order) { - max_id.store(id, order); - } - - lock.add(Node { id, hash: bytes }); - - count += 1; - }); - - tracing::trace!("Added {} new hashes", count); - }) - .await; + tracing::error!("Lost connection to Postgres, recreating tree"); + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + let new_tree = create_tree(&conn).await; + let mut lock = tree.write().await; + *lock = new_tree; + drop(lock); + tracing::info!("Replaced tree"); + } }); - - let log = warp::log("fuzzysearch"); - let cors = warp::cors() - .allow_any_origin() - .allow_headers(vec!["x-api-key"]) - .allow_methods(vec!["GET", "POST"]); - - let options = warp::options().map(|| "✓"); - - let api = options.or(filters::search(db_pool, tree)); - let routes = api - .or(warp::path::end() - .map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net")))) - .with(log) - .with(cors) - .recover(handlers::handle_rejection); - - warp::serve(routes).run(([0, 0, 0, 0], 8080)).await; } - -type Pool = bb8::Pool>; diff --git a/fuzzysearch/src/models.rs b/fuzzysearch/src/models.rs index ecefce7..dd943ef 100644 --- a/fuzzysearch/src/models.rs +++ b/fuzzysearch/src/models.rs @@ -1,46 +1,48 @@ -use crate::types::*; -use crate::utils::extract_rows; -use crate::{Pool, Tree}; +use lazy_static::lazy_static; +use 
prometheus::{register_histogram, Histogram}; use tracing_futures::Instrument; -use fuzzysearch_common::types::SearchResult; +use crate::types::*; +use crate::{Pool, Tree}; +use fuzzysearch_common::types::{SearchResult, SiteInfo}; -pub type DB<'a> = - &'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager>; +lazy_static! { + static ref IMAGE_LOOKUP_DURATION: Histogram = register_histogram!( + "fuzzysearch_api_image_lookup_seconds", + "Duration to perform an image lookup" + ) + .unwrap(); + static ref IMAGE_QUERY_DURATION: Histogram = register_histogram!( + "fuzzysearch_api_image_query_seconds", + "Duration to perform a single image lookup query" + ) + .unwrap(); +} #[tracing::instrument(skip(db))] -pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option { - let rows = db - .query( - "SELECT +pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option { + sqlx::query_as!( + ApiKey, + "SELECT api_key.id, api_key.name_limit, api_key.image_limit, api_key.hash_limit, api_key.name, - account.email + account.email owner_email FROM api_key JOIN account ON account.id = api_key.user_id WHERE - api_key.key = $1", - &[&key], - ) - .await - .expect("Unable to query API keys"); - - match rows.into_iter().next() { - Some(row) => Some(ApiKey { - id: row.get(0), - name_limit: row.get(1), - image_limit: row.get(2), - hash_limit: row.get(3), - name: row.get(4), - owner_email: row.get(5), - }), - _ => None, - } + api_key.key = $1 + ", + key + ) + .fetch_optional(db) + .await + .ok() + .flatten() } #[tracing::instrument(skip(pool, tree))] @@ -50,7 +52,7 @@ pub async fn image_query( hashes: Vec, distance: i64, hash: Option>, -) -> Result, tokio_postgres::Error> { +) -> Result, sqlx::Error> { let mut results = image_query_sync(pool, tree, hashes, distance, hash); let mut matches = Vec::new(); @@ -68,19 +70,30 @@ pub fn image_query_sync( hashes: Vec, distance: i64, hash: Option>, -) -> tokio::sync::mpsc::Receiver, tokio_postgres::Error>> { - let (mut tx, rx) = tokio::sync::mpsc::channel(50); +) -> tokio::sync::mpsc::Receiver, sqlx::Error>> { + let (tx, rx) = tokio::sync::mpsc::channel(50); tokio::spawn(async move { - let db = pool.get().await.unwrap(); + let db = pool; for query_hash in hashes { + let mut seen = std::collections::HashSet::new(); + + let _timer = IMAGE_LOOKUP_DURATION.start_timer(); + let node = crate::Node::query(query_hash.to_be_bytes()); let lock = tree.read().await; let items = lock.find(&node, distance as u64); - for (_dist, item) in items { - let query = db.query("SELECT + for (dist, item) in items { + if seen.contains(&item.id) { + continue; + } + seen.insert(item.id); + + let _timer = IMAGE_QUERY_DURATION.start_timer(); + + let row = sqlx::query!("SELECT hashes.id, hashes.hash, hashes.furaffinity_id, @@ -106,7 +119,16 @@ pub fn image_query_sync( END file_id, CASE WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources')) - END sources + END sources, + CASE + WHEN furaffinity_id IS NOT NULL THEN (f.rating) + WHEN e621_id IS NOT NULL THEN (e.data->>'rating') + WHEN twitter_id IS NOT NULL THEN + CASE + WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult' + WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general' + END + END rating FROM hashes LEFT JOIN LATERAL ( @@ -133,14 +155,45 @@ pub fn image_query_sync( tweet_media.hash <@ (hashes.hash, 0) LIMIT 1 ) tm ON hashes.twitter_id IS NOT NULL - WHERE hashes.id = $1", &[&item.id]).await; - let rows = query.map(|rows| { - extract_rows(rows, 
hash.as_deref()).into_iter().map(|mut file| { - file.searched_hash = Some(query_hash); - file - }).collect() - }); - tx.send(rows).await.unwrap(); + WHERE hashes.id = $1", item.id).map(|row| { + let (site_id, site_info) = if let Some(fa_id) = row.furaffinity_id { + ( + fa_id as i64, + Some(SiteInfo::FurAffinity { + file_id: row.file_id.unwrap(), + }) + ) + } else if let Some(e621_id) = row.e621_id { + ( + e621_id as i64, + Some(SiteInfo::E621 { + sources: row.sources, + }) + ) + } else if let Some(twitter_id) = row.twitter_id { + (twitter_id, Some(SiteInfo::Twitter)) + } else { + (-1, None) + }; + + let file = SearchResult { + id: row.id, + site_id, + site_info, + rating: row.rating.and_then(|rating| rating.parse().ok()), + site_id_str: site_id.to_string(), + url: row.url.unwrap_or_default(), + hash: Some(row.hash), + distance: Some(dist), + artists: row.artists, + filename: row.filename.unwrap_or_default(), + searched_hash: Some(query_hash), + }; + + vec![file] + }).fetch_one(&db).await; + + tx.send(row).await.unwrap(); } } }.in_current_span()); diff --git a/fuzzysearch/src/types.rs b/fuzzysearch/src/types.rs index 4340be9..ef43cc1 100644 --- a/fuzzysearch/src/types.rs +++ b/fuzzysearch/src/types.rs @@ -10,7 +10,7 @@ use fuzzysearch_common::types::SearchResult; pub struct ApiKey { pub id: i32, pub name: Option, - pub owner_email: Option, + pub owner_email: String, pub name_limit: i16, pub image_limit: i16, pub hash_limit: i16, @@ -22,7 +22,7 @@ pub enum RateLimit { /// This key is limited, we should deny the request. Limited, /// This key is available, contains the number of requests made. - Available(i16), + Available((i16, i16)), } #[derive(Debug, Deserialize)] @@ -68,3 +68,8 @@ pub struct HashSearchOpts { pub struct HandleOpts { pub twitter: Option, } + +#[derive(Debug, Deserialize)] +pub struct UrlSearchOpts { + pub url: String, +} diff --git a/fuzzysearch/src/utils.rs b/fuzzysearch/src/utils.rs index 89a2911..dad259a 100644 --- a/fuzzysearch/src/utils.rs +++ b/fuzzysearch/src/utils.rs @@ -1,7 +1,15 @@ -use crate::models::DB; use crate::types::*; +use lazy_static::lazy_static; +use prometheus::{register_int_counter_vec, IntCounterVec}; -use fuzzysearch_common::types::{SearchResult, SiteInfo}; +lazy_static! { + pub static ref RATE_LIMIT_STATUS: IntCounterVec = register_int_counter_vec!( + "fuzzysearch_api_rate_limit_count", + "Number of allowed and rate limited requests", + &["status"] + ) + .unwrap(); +} #[macro_export] macro_rules! rate_limit { @@ -9,18 +17,48 @@ macro_rules! 
rate_limit { rate_limit!($api_key, $db, $limit, $group, 1) }; - ($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => { - let api_key = crate::models::lookup_api_key($api_key, $db) - .await - .ok_or_else(|| warp::reject::custom(Error::ApiKey))?; + ($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {{ + let api_key = match crate::models::lookup_api_key($api_key, $db).await { + Some(api_key) => api_key, + None => return Ok(Box::new(Error::ApiKey)), + }; - let rate_limit = - crate::utils::update_rate_limit($db, api_key.id, api_key.$limit, $group, $incr_by) - .await - .map_err(crate::handlers::map_postgres_err)?; + let rate_limit = match crate::utils::update_rate_limit( + $db, + api_key.id, + api_key.$limit, + $group, + $incr_by, + ) + .await + { + Ok(rate_limit) => rate_limit, + Err(err) => return Ok(Box::new(Error::Postgres(err))), + }; - if rate_limit == crate::types::RateLimit::Limited { - return Err(warp::reject::custom(Error::RateLimit)); + match rate_limit { + crate::types::RateLimit::Limited => { + crate::utils::RATE_LIMIT_STATUS + .with_label_values(&["limited"]) + .inc(); + return Ok(Box::new(Error::RateLimit)); + } + crate::types::RateLimit::Available(count) => { + crate::utils::RATE_LIMIT_STATUS + .with_label_values(&["allowed"]) + .inc(); + count + } + } + }}; +} + +#[macro_export] +macro_rules! early_return { + ($val:expr) => { + match $val { + Ok(val) => val, + Err(err) => return Ok(Box::new(Error::from(err))), } }; } @@ -33,85 +71,38 @@ macro_rules! rate_limit { /// joined requests. #[tracing::instrument(skip(db))] pub async fn update_rate_limit( - db: DB<'_>, + db: &sqlx::PgPool, key_id: i32, key_group_limit: i16, group_name: &'static str, incr_by: i16, -) -> Result { +) -> Result { let now = chrono::Utc::now(); let timestamp = now.timestamp(); let time_window = timestamp - (timestamp % 60); - let rows = db - .query( - "INSERT INTO - rate_limit (api_key_id, time_window, group_name, count) - VALUES - ($1, $2, $3, $4) - ON CONFLICT ON CONSTRAINT unique_window - DO UPDATE set count = rate_limit.count + $4 - RETURNING rate_limit.count", - &[&key_id, &time_window, &group_name, &incr_by], - ) - .await?; - - let count: i16 = rows[0].get(0); + let count: i16 = sqlx::query_scalar!( + "INSERT INTO + rate_limit (api_key_id, time_window, group_name, count) + VALUES + ($1, $2, $3, $4) + ON CONFLICT ON CONSTRAINT unique_window + DO UPDATE set count = rate_limit.count + $4 + RETURNING rate_limit.count", + key_id, + time_window, + group_name, + incr_by + ) + .fetch_one(db) + .await?; if count > key_group_limit { Ok(RateLimit::Limited) } else { - Ok(RateLimit::Available(count)) + Ok(RateLimit::Available(( + key_group_limit - count, + key_group_limit, + ))) } } - -pub fn extract_rows<'a>( - rows: Vec, - hash: Option<&'a [u8]>, -) -> impl IntoIterator + 'a { - rows.into_iter().map(move |row| { - let dbhash: i64 = row.get("hash"); - let dbbytes = dbhash.to_be_bytes(); - - let (furaffinity_id, e621_id, twitter_id): (Option, Option, Option) = ( - row.get("furaffinity_id"), - row.get("e621_id"), - row.get("twitter_id"), - ); - - let (site_id, site_info) = if let Some(fa_id) = furaffinity_id { - ( - fa_id as i64, - Some(SiteInfo::FurAffinity { - file_id: row.get("file_id"), - }), - ) - } else if let Some(e6_id) = e621_id { - ( - e6_id as i64, - Some(SiteInfo::E621 { - sources: row.get("sources"), - }), - ) - } else if let Some(t_id) = twitter_id { - (t_id, Some(SiteInfo::Twitter)) - } else { - (-1, None) - }; - - SearchResult { - id: row.get("id"), - site_id, - 
site_info, - site_id_str: site_id.to_string(), - url: row.get("url"), - hash: Some(dbhash), - distance: hash - .map(|hash| hamming::distance_fast(&dbbytes, &hash).ok()) - .flatten(), - artists: row.get("artists"), - filename: row.get("filename"), - searched_hash: None, - } - }) -}
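A minimal client sketch for the new GET /url route (filters.rs wires warp::path("url") to handlers::search_image_by_url, which downloads the image, hashes it, and returns JSON matches plus informational headers). Assumptions not in the diff: a FuzzySearch instance listening on localhost:8080 (matching the Dockerfile's EXPOSE 8080), a placeholder API key, and a hypothetical image URL; the x-api-key header name is taken from the CORS allow-list in main.rs, and the reqwest 0.11 / tokio 1 dependencies mirror those added in fuzzysearch/Cargo.toml.

use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Ask the server to fetch and hash a remote image, then return any matches.
    let resp = reqwest::Client::new()
        .get("http://localhost:8080/url")
        .query(&[("url", "https://example.com/image.png")]) // hypothetical image URL
        .header("x-api-key", "YOUR_API_KEY") // hypothetical key
        .send()
        .await?;

    // The handler surfaces the computed hash and rate-limit state as headers.
    if let Some(hash) = resp.headers().get("x-image-hash") {
        println!("computed hash: {}", hash.to_str()?);
    }
    println!("matches: {}", resp.text().await?);

    Ok(())
}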