Merge branch 'master' into unify

This commit is contained in:
Syfaro 2021-02-20 18:05:04 -05:00
commit 8914227f23
13 changed files with 943 additions and 460 deletions

View File

@ -33,4 +33,8 @@ steps:
from_secret: sccache_s3_endpoint
SCCACHE_S3_USE_SSL: true
---
kind: signature
hmac: 665dab5e07086669c4b215ed86faa0e1e63c495b0bf020099fb1edd33757618b
...

125
Cargo.lock generated
View File

@ -61,6 +61,27 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
[[package]]
name = "async-stream"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3670df70cbc01729f901f94c887814b3c68db038aad1329a418bae178bc5295c"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3548b8efc9f8e8a5a0a2808c5bd8451a9031b9e5b879a79590304ae928b0a70"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.42"
@ -105,30 +126,6 @@ version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bb8"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374bba43fc924d90393ee7768e6f75d223a98307a488fe5bc34b66c3e96932a6"
dependencies = [
"async-trait",
"futures",
"tokio 0.2.25",
]
[[package]]
name = "bb8-postgres"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39a233af6ea3952e20d01863c87b4f6689b2f806249688b0908b5f02d4fa41ac"
dependencies = [
"async-trait",
"bb8",
"futures",
"tokio 0.2.25",
"tokio-postgres 0.5.5",
]
[[package]]
name = "bindgen"
version = "0.54.0"
@ -835,25 +832,27 @@ dependencies = [
name = "fuzzysearch"
version = "0.1.0"
dependencies = [
"anyhow",
"bb8",
"bb8-postgres",
"async-stream",
"bk-tree",
"bytes 0.5.6",
"bytes 1.0.1",
"chrono",
"ffmpeg-next",
"futures",
"futures-util",
"fuzzysearch-common",
"hamming",
"hyper 0.14.4",
"image",
"img_hash",
"infer",
"lazy_static",
"opentelemetry",
"opentelemetry-jaeger",
"prometheus 0.11.0",
"reqwest 0.11.1",
"serde",
"tokio 0.2.25",
"tokio-postgres 0.5.5",
"serde_json",
"sqlx 0.5.1",
"tokio 1.2.0",
"tracing",
"tracing-futures",
"tracing-opentelemetry",
@ -1338,11 +1337,11 @@ checksum = "8906512588cd815b8f759fd0ac11de2a84c985c0f792f70df611e9325c270c1f"
[[package]]
name = "input_buffer"
version = "0.3.1"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19a8a95243d5a0398cae618ec29477c6e3cb631152be5c19481f80bc71559754"
checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
dependencies = [
"bytes 0.5.6",
"bytes 1.0.1",
]
[[package]]
@ -2530,7 +2529,7 @@ dependencies = [
"pin-project-lite 0.2.4",
"serde",
"serde_json",
"serde_urlencoded 0.7.0",
"serde_urlencoded",
"tokio 0.2.25",
"tokio-tls",
"url",
@ -2565,7 +2564,7 @@ dependencies = [
"pin-project-lite 0.2.4",
"serde",
"serde_json",
"serde_urlencoded 0.7.0",
"serde_urlencoded",
"tokio 1.2.0",
"tokio-native-tls",
"url",
@ -2824,18 +2823,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ec5d77e2d4c73717816afac02670d5c4f534ea95ed430442cad02e7a6e32c97"
dependencies = [
"dtoa",
"itoa",
"serde",
"url",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.0"
@ -3516,14 +3503,14 @@ dependencies = [
[[package]]
name = "tokio-tungstenite"
version = "0.11.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d9e878ad426ca286e4dcae09cbd4e1973a7f8987d97570e2469703dd7f5720c"
checksum = "e1a5f475f1b9d077ea1017ecbc60890fda8e54942d680ca0b1d2b47cfa2d861b"
dependencies = [
"futures-util",
"log",
"pin-project 0.4.27",
"tokio 0.2.25",
"pin-project 1.0.5",
"tokio 1.2.0",
"tungstenite",
]
@ -3699,18 +3686,18 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0308d80d86700c5878b9ef6321f020f29b1bb9d5ff3cab25e75e23f3a492a23"
checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24"
dependencies = [
"base64 0.12.3",
"base64 0.13.0",
"byteorder",
"bytes 0.5.6",
"bytes 1.0.1",
"http",
"httparse",
"input_buffer",
"log",
"rand 0.7.3",
"rand 0.8.3",
"sha-1 0.9.4",
"url",
"utf-8",
@ -3812,12 +3799,6 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9232eb53352b4442e40d7900465dfc534e8cb2dc8f18656fcb2ac16112b5593"
[[package]]
name = "utf-8"
version = "0.7.5"
@ -3848,30 +3829,32 @@ dependencies = [
[[package]]
name = "warp"
version = "0.2.5"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f41be6df54c97904af01aa23e613d4521eed7ab23537cede692d4058f6449407"
checksum = "3dafd0aac2818a94a34df0df1100a7356c493d8ede4393875fd0b5c51bb6bc80"
dependencies = [
"bytes 0.5.6",
"bytes 1.0.1",
"futures",
"headers",
"http",
"hyper 0.13.10",
"hyper 0.14.4",
"log",
"mime",
"mime_guess",
"multipart",
"pin-project 0.4.27",
"percent-encoding",
"pin-project 1.0.5",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded 0.6.1",
"tokio 0.2.25",
"serde_urlencoded",
"tokio 1.2.0",
"tokio-stream",
"tokio-tungstenite",
"tokio-util 0.6.3",
"tower-service",
"tracing",
"tracing-futures",
"urlencoding",
]
[[package]]

View File

@ -1,5 +1,28 @@
use serde::Serialize;
#[derive(Debug, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Rating {
General,
Mature,
Adult,
}
impl std::str::FromStr for Rating {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let rating = match s {
"g" | "s" | "general" => Self::General,
"m" | "q" | "mature" => Self::Mature,
"a" | "e" | "adult" => Self::Adult,
_ => return Err("unknown rating"),
};
Ok(rating)
}
}
/// A general type for every result in a search.
#[derive(Debug, Default, Serialize)]
pub struct SearchResult {
@ -11,6 +34,7 @@ pub struct SearchResult {
pub url: String,
pub filename: String,
pub artists: Option<Vec<String>>,
pub rating: Option<Rating>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(flatten)]

View File

@ -165,7 +165,7 @@ async fn main() {
tokio::spawn(async move {
if let Err(e) = connection.await {
panic!(e);
panic!("postgres connection error: {:?}", e);
}
});

View File

@ -9,29 +9,33 @@ tracing = "0.1"
tracing-subscriber = "0.2"
tracing-futures = "0.2"
prometheus = { version = "0.11", features = ["process"] }
lazy_static = "1"
opentelemetry = "0.6"
opentelemetry-jaeger = "0.5"
tracing-opentelemetry = "0.5"
tokio = { version = "0.2", features = ["full"] }
futures = "0.3"
futures-util = "0.3"
tokio = { version = "1", features = ["full"] }
async-stream = "0.3"
futures = "0.3"
anyhow = "1"
chrono = "0.4"
bytes = "0.5"
infer = { version = "0.3", default-features = false }
bytes = "1"
serde = { version = "1", features = ["derive"] }
warp = "0.2"
serde_json = "1"
tokio-postgres = "0.5"
bb8 = "0.4"
bb8-postgres = "0.4"
warp = "0.3"
reqwest = "0.11"
hyper = "0.14"
image = "0.23"
sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "postgres", "macros", "json", "offline"] }
infer = { version = "0.3", default-features = false }
ffmpeg-next = "4"
image = "0.23"
img_hash = "3"
hamming = "0.1"

View File

@ -1,10 +1,14 @@
FROM rust:1-slim AS builder
WORKDIR /src
ENV SQLX_OFFLINE=true
RUN apt-get update -y && apt-get install -y libssl-dev pkg-config
COPY . .
RUN cargo install --root / --path .
FROM debian:buster-slim
EXPOSE 8080
EXPOSE 8080 8081
ENV METRICS_HOST=0.0.0.0:8081
WORKDIR /app
RUN apt-get update -y && apt-get install -y openssl ca-certificates && rm -rf /var/lib/apt/lists/*
COPY --from=builder /bin/fuzzysearch /bin/fuzzysearch
CMD ["/bin/fuzzysearch"]

200
fuzzysearch/sqlx-data.json Normal file
View File

@ -0,0 +1,200 @@
{
"db": "PostgreSQL",
"1984ce60f052d6a29638f8e05b35671b8edfbf273783d4b843ebd35cbb8a391f": {
"query": "INSERT INTO\n rate_limit (api_key_id, time_window, group_name, count)\n VALUES\n ($1, $2, $3, $4)\n ON CONFLICT ON CONSTRAINT unique_window\n DO UPDATE set count = rate_limit.count + $4\n RETURNING rate_limit.count",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "count",
"type_info": "Int2"
}
],
"parameters": {
"Left": [
"Int4",
"Int8",
"Text",
"Int2"
]
},
"nullable": [
false
]
}
},
"1bd0057782de5a3b41f90081a31d24d14bb70299391050c3404742a6d2915d9e": {
"query": "SELECT\n hashes.id,\n hashes.hash,\n hashes.furaffinity_id,\n hashes.e621_id,\n hashes.twitter_id,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.url)\n WHEN e621_id IS NOT NULL THEN (e.data->'file'->>'url')\n WHEN twitter_id IS NOT NULL THEN (tm.url)\n END url,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.filename)\n WHEN e621_id IS NOT NULL THEN ((e.data->'file'->>'md5') || '.' || (e.data->'file'->>'ext'))\n WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))\n END filename,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'tags'->'artist'))\n WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')\n END artists,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.file_id)\n END file_id,\n CASE\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))\n END sources,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.rating)\n WHEN e621_id IS NOT NULL THEN (e.data->>'rating')\n WHEN twitter_id IS NOT NULL THEN\n CASE\n WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult'\n WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general'\n END\n END rating\n FROM\n hashes\n LEFT JOIN LATERAL (\n SELECT *\n FROM submission\n JOIN artist ON submission.artist_id = artist.id\n WHERE submission.id = hashes.furaffinity_id\n ) f ON hashes.furaffinity_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM e621\n WHERE e621.id = hashes.e621_id\n ) e ON hashes.e621_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet\n WHERE tweet.id = hashes.twitter_id\n ) tw ON hashes.twitter_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet_media\n WHERE\n tweet_media.tweet_id = hashes.twitter_id AND\n tweet_media.hash <@ (hashes.hash, 0)\n LIMIT 1\n ) tm ON hashes.twitter_id IS NOT NULL\n WHERE hashes.id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "hash",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "furaffinity_id",
"type_info": "Int4"
},
{
"ordinal": 3,
"name": "e621_id",
"type_info": "Int4"
},
{
"ordinal": 4,
"name": "twitter_id",
"type_info": "Int8"
},
{
"ordinal": 5,
"name": "url",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "filename",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "artists",
"type_info": "TextArray"
},
{
"ordinal": 8,
"name": "file_id",
"type_info": "Int4"
},
{
"ordinal": 9,
"name": "sources",
"type_info": "TextArray"
},
{
"ordinal": 10,
"name": "rating",
"type_info": "Bpchar"
}
],
"parameters": {
"Left": [
"Int4"
]
},
"nullable": [
false,
false,
true,
true,
true,
null,
null,
null,
null,
null,
null
]
}
},
"659ee9ddc1c5ccd42ba9dc1617440544c30ece449ba3ba7f9d39f447b8af3cfe": {
"query": "SELECT\n api_key.id,\n api_key.name_limit,\n api_key.image_limit,\n api_key.hash_limit,\n api_key.name,\n account.email owner_email\n FROM\n api_key\n JOIN account\n ON account.id = api_key.user_id\n WHERE\n api_key.key = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "name_limit",
"type_info": "Int2"
},
{
"ordinal": 2,
"name": "image_limit",
"type_info": "Int2"
},
{
"ordinal": 3,
"name": "hash_limit",
"type_info": "Int2"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "owner_email",
"type_info": "Varchar"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
true,
false
]
}
},
"6b8d304fc40fa539ae671e6e24e7978ad271cb7a1cafb20fc4b4096a958d790f": {
"query": "SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "exists",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
null
]
}
},
"fe60be66b2d8a8f02b3bfe06d1f0e57e4bb07e80cba1b379a5f17f6cbd8b075c": {
"query": "SELECT id, hash FROM hashes",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "hash",
"type_info": "Int8"
}
],
"parameters": {
"Left": []
},
"nullable": [
false,
false
]
}
}
}

View File

@ -10,10 +10,11 @@ pub fn search(
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_image(db.clone(), tree.clone())
.or(search_hashes(db.clone(), tree.clone()))
.or(stream_search_image(db.clone(), tree))
.or(stream_search_image(db.clone(), tree.clone()))
.or(search_file(db.clone()))
.or(search_video(db.clone()))
.or(check_handle(db))
.or(check_handle(db.clone()))
.or(search_image_by_url(db, tree))
}
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@ -55,6 +56,19 @@ pub fn search_image(
})
}
pub fn search_image_by_url(
db: Pool,
tree: Tree,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("url")
.and(warp::get())
.and(warp::query::<UrlSearchOpts>())
.and(with_pool(db))
.and(with_tree(tree))
.and(with_api_key())
.and_then(handlers::search_image_by_url)
}
pub fn search_hashes(
db: Pool,
tree: Tree,

View File

@ -1,46 +1,92 @@
use crate::models::{image_query, image_query_sync};
use crate::types::*;
use crate::{rate_limit, Pool, Tree};
use lazy_static::lazy_static;
use prometheus::{register_histogram, register_int_counter, Histogram, IntCounter};
use std::convert::TryInto;
use tracing::{span, warn};
use tracing_futures::Instrument;
use warp::{reject, Rejection, Reply};
use warp::{Rejection, Reply};
use crate::models::{image_query, image_query_sync};
use crate::types::*;
use crate::{early_return, rate_limit, Pool, Tree};
use fuzzysearch_common::types::{SearchResult, SiteInfo};
fn map_bb8_err(err: bb8::RunError<tokio_postgres::Error>) -> Rejection {
reject::custom(Error::from(err))
}
fn map_postgres_err(err: tokio_postgres::Error) -> Rejection {
reject::custom(Error::from(err))
lazy_static! {
static ref IMAGE_HASH_DURATION: Histogram = register_histogram!(
"fuzzysearch_api_image_hash_seconds",
"Duration to perform an image hash operation"
)
.unwrap();
static ref IMAGE_URL_DOWNLOAD_DURATION: Histogram = register_histogram!(
"fuzzysearch_api_image_url_download_seconds",
"Duration to download an image from a provided URL"
)
.unwrap();
static ref UNHANDLED_REJECTIONS: IntCounter = register_int_counter!(
"fuzzysearch_api_unhandled_rejections_count",
"Number of unhandled HTTP rejections"
)
.unwrap();
}
#[derive(Debug)]
enum Error {
BB8(bb8::RunError<tokio_postgres::Error>),
Postgres(tokio_postgres::Error),
Postgres(sqlx::Error),
Reqwest(reqwest::Error),
InvalidData,
InvalidImage,
ApiKey,
RateLimit,
}
impl From<bb8::RunError<tokio_postgres::Error>> for Error {
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self {
Error::BB8(err)
impl warp::Reply for Error {
fn into_response(self) -> warp::reply::Response {
let msg = match self {
Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
code: 500,
message: "Internal server error".to_string(),
},
Error::InvalidData => ErrorMessage {
code: 400,
message: "Invalid data provided".to_string(),
},
Error::InvalidImage => ErrorMessage {
code: 400,
message: "Invalid image provided".to_string(),
},
Error::ApiKey => ErrorMessage {
code: 401,
message: "Invalid API key".to_string(),
},
Error::RateLimit => ErrorMessage {
code: 429,
message: "Too many requests".to_string(),
},
};
let body = hyper::body::Body::from(serde_json::to_string(&msg).unwrap());
warp::http::Response::builder()
.status(msg.code)
.body(body)
.unwrap()
}
}
impl From<tokio_postgres::Error> for Error {
fn from(err: tokio_postgres::Error) -> Self {
impl From<sqlx::Error> for Error {
fn from(err: sqlx::Error) -> Self {
Error::Postgres(err)
}
}
impl warp::reject::Reject for Error {}
impl From<reqwest::Error> for Error {
fn from(err: reqwest::Error) -> Self {
Error::Reqwest(err)
}
}
async fn get_field_bytes(form: warp::multipart::FormData, field: &str) -> bytes::BytesMut {
use bytes::BufMut;
use futures_util::StreamExt;
use futures::StreamExt;
let parts: Vec<_> = form.collect().await;
let mut parts = parts
@ -66,6 +112,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
let len = bytes.len();
let _timer = IMAGE_HASH_DURATION.start_timer();
let hash = tokio::task::spawn_blocking(move || {
let hasher = fuzzysearch_common::get_hasher();
let image = image::load_from_memory(&bytes).unwrap();
@ -74,6 +121,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
.instrument(span!(tracing::Level::TRACE, "hashing image", len))
.await
.unwrap();
drop(_timer);
let mut buf: [u8; 8] = [0; 8];
buf.copy_from_slice(&hash.as_bytes());
@ -83,7 +131,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
#[tracing::instrument(skip(form))]
async fn hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> {
use bytes::buf::BufExt;
use bytes::Buf;
let bytes = get_field_bytes(form, "video").await;
@ -106,21 +154,19 @@ async fn hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> {
pub async fn search_image(
form: warp::multipart::FormData,
opts: ImageSearchOpts,
pool: Pool,
db: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image");
rate_limit!(&api_key, &db, hash_limit, "hash");
) -> Result<Box<dyn Reply>, Rejection> {
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
let (num, hash) = hash_input(form).await;
let mut items = {
if opts.search_type == Some(ImageSearchType::Force) {
image_query(
pool.clone(),
db.clone(),
tree.clone(),
vec![num],
10,
@ -130,7 +176,7 @@ pub async fn search_image(
.unwrap()
} else {
let results = image_query(
pool.clone(),
db.clone(),
tree.clone(),
vec![num],
0,
@ -140,7 +186,7 @@ pub async fn search_image(
.unwrap();
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
image_query(
pool.clone(),
db.clone(),
tree.clone(),
vec![num],
10,
@ -166,7 +212,20 @@ pub async fn search_image(
matches: items,
};
Ok(warp::reply::json(&similarity))
let resp = warp::http::Response::builder()
.header("x-image-hash", num.to_string())
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&similarity).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn stream_image(
@ -174,34 +233,36 @@ pub async fn stream_image(
pool: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
use futures_util::StreamExt;
let db = pool.get().await.map_err(map_bb8_err)?;
rate_limit!(&api_key, &db, image_limit, "image", 2);
rate_limit!(&api_key, &db, hash_limit, "hash");
) -> Result<Box<dyn Reply>, Rejection> {
rate_limit!(&api_key, &pool, image_limit, "image", 2);
rate_limit!(&api_key, &pool, hash_limit, "hash");
let (num, hash) = hash_input(form).await;
let event_stream = image_query_sync(
let mut query = image_query_sync(
pool.clone(),
tree,
vec![num],
10,
Some(hash.as_bytes().to_vec()),
)
.map(sse_matches);
);
Ok(warp::sse::reply(event_stream))
let event_stream = async_stream::stream! {
while let Some(result) = query.recv().await {
yield sse_matches(result);
}
};
Ok(Box::new(warp::sse::reply(event_stream)))
}
#[allow(clippy::unnecessary_wraps)]
fn sse_matches(
matches: Result<Vec<SearchResult>, tokio_postgres::Error>,
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
matches: Result<Vec<SearchResult>, sqlx::Error>,
) -> Result<warp::sse::Event, core::convert::Infallible> {
let items = matches.unwrap();
Ok(warp::sse::json(items))
Ok(warp::sse::Event::default().json_data(items).unwrap())
}
pub async fn search_hashes(
@ -209,9 +270,8 @@ pub async fn search_hashes(
db: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
) -> Result<Box<dyn Reply>, Rejection> {
let pool = db.clone();
let db = db.get().await.map_err(map_bb8_err)?;
let hashes: Vec<i64> = opts
.hashes
@ -221,10 +281,10 @@ pub async fn search_hashes(
.collect();
if hashes.is_empty() {
return Err(warp::reject::custom(Error::InvalidData));
return Ok(Box::new(Error::InvalidData));
}
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let mut results = image_query_sync(
pool,
@ -236,38 +296,39 @@ pub async fn search_hashes(
let mut matches = Vec::new();
while let Some(r) = results.recv().await {
matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
matches.extend(early_return!(r));
}
Ok(warp::reply::json(&matches))
let resp = warp::http::Response::builder()
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("content-type", "application/json")
.body(serde_json::to_string(&matches).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn search_file(
opts: FileSearchOpts,
db: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
) -> Result<Box<dyn Reply>, Rejection> {
use sqlx::Row;
rate_limit!(&api_key, &db, name_limit, "file");
let file_remaining = rate_limit!(&api_key, &db, name_limit, "file");
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) =
if let Some(ref id) = opts.id {
("file_id = $1", id)
} else if let Some(ref name) = opts.name {
("lower(filename) = lower($1)", name)
} else if let Some(ref url) = opts.url {
("lower(url) = lower($1)", url)
} else {
return Err(warp::reject::custom(Error::InvalidData));
};
let query = format!(
let query = if let Some(ref id) = opts.id {
sqlx::query(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
submission.rating,
artist.name,
hashes.id hash_id
FROM
@ -277,25 +338,65 @@ pub async fn search_file(
JOIN hashes
ON hashes.furaffinity_id = submission.id
WHERE
{}
file_id = $1
LIMIT 10",
filter
);
)
.bind(id)
} else if let Some(ref name) = opts.name {
sqlx::query(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
submission.rating,
artist.name,
hashes.id hash_id
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
JOIN hashes
ON hashes.furaffinity_id = submission.id
WHERE
lower(filename) = lower($1)
LIMIT 10",
)
.bind(name)
} else if let Some(ref url) = opts.url {
sqlx::query(
"SELECT
submission.id,
submission.url,
submission.filename,
submission.file_id,
submission.rating,
artist.name,
hashes.id hash_id
FROM
submission
JOIN artist
ON artist.id = submission.artist_id
JOIN hashes
ON hashes.furaffinity_id = submission.id
WHERE
lower(url) = lower($1)
LIMIT 10",
)
.bind(url)
} else {
return Ok(Box::new(Error::InvalidData));
};
let matches: Vec<_> = db
.query::<str>(&*query, &[val])
.instrument(span!(tracing::Level::TRACE, "waiting for db"))
.await
.map_err(map_postgres_err)?
.into_iter()
let matches: Result<Vec<SearchResult>, _> = query
.map(|row| SearchResult {
id: row.get("hash_id"),
site_id: row.get::<&str, i32>("id") as i64,
site_id_str: row.get::<&str, i32>("id").to_string(),
site_id: row.get::<i32, _>("id") as i64,
site_id_str: row.get::<i32, _>("id").to_string(),
url: row.get("url"),
filename: row.get("filename"),
artists: row
.get::<&str, Option<String>>("name")
.get::<Option<String>, _>("name")
.map(|artist| vec![artist]),
distance: None,
hash: None,
@ -303,10 +404,23 @@ pub async fn search_file(
file_id: row.get("file_id"),
}),
searched_hash: None,
rating: row
.get::<Option<String>, _>("rating")
.and_then(|rating| rating.parse().ok()),
})
.collect();
.fetch_all(&db)
.await;
Ok(warp::reply::json(&matches))
let matches = early_return!(matches);
let resp = warp::http::Response::builder()
.header("x-rate-limit-total-file", file_remaining.1.to_string())
.header("x-rate-limit-remaining-file", file_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&matches).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn search_video(
@ -319,56 +433,116 @@ pub async fn search_video(
Ok(warp::reply::json(&hashes))
}
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> {
let exists = if let Some(handle) = opts.twitter {
!db.query(
"SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)",
&[&handle],
)
let result = sqlx::query_scalar!("SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", handle)
.fetch_optional(&db)
.await
.map_err(map_postgres_err)?
.is_empty()
.map(|row| row.flatten().unwrap_or(false));
early_return!(result)
} else {
false
};
Ok(warp::reply::json(&exists))
Ok(Box::new(warp::reply::json(&exists)))
}
pub async fn search_image_by_url(
opts: UrlSearchOpts,
db: Pool,
tree: Tree,
api_key: String,
) -> Result<Box<dyn Reply>, Rejection> {
use bytes::BufMut;
let url = opts.url;
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
let _timer = IMAGE_URL_DOWNLOAD_DURATION.start_timer();
let mut resp = match reqwest::get(&url).await {
Ok(resp) => resp,
Err(_err) => return Ok(Box::new(Error::InvalidImage)),
};
let content_length = resp
.headers()
.get("content-length")
.and_then(|len| {
String::from_utf8_lossy(len.as_bytes())
.parse::<usize>()
.ok()
})
.unwrap_or(0);
if content_length > 10_000_000 {
return Ok(Box::new(Error::InvalidImage));
}
let mut buf = bytes::BytesMut::with_capacity(content_length);
while let Some(chunk) = early_return!(resp.chunk().await) {
if buf.len() + chunk.len() > 10_000_000 {
return Ok(Box::new(Error::InvalidImage));
}
buf.put(chunk);
}
drop(_timer);
let _timer = IMAGE_HASH_DURATION.start_timer();
let hash = tokio::task::spawn_blocking(move || {
let hasher = fuzzysearch_common::get_hasher();
let image = image::load_from_memory(&buf).unwrap();
hasher.hash_image(&image)
})
.instrument(span!(tracing::Level::TRACE, "hashing image"))
.await
.unwrap();
drop(_timer);
let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
let num = i64::from_be_bytes(hash);
let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec()))
.await
.unwrap();
let resp = warp::http::Response::builder()
.header("x-image-hash", num.to_string())
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&results).unwrap())
.unwrap();
Ok(Box::new(resp))
}
#[tracing::instrument]
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
pub async fn handle_rejection(err: Rejection) -> Result<Box<dyn Reply>, std::convert::Infallible> {
warn!("had rejection");
UNHANDLED_REJECTIONS.inc();
let (code, message) = if err.is_not_found() {
(
warp::http::StatusCode::NOT_FOUND,
"This page does not exist",
)
} else if let Some(err) = err.find::<Error>() {
match err {
Error::BB8(_inner) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
"A database error occured",
),
Error::Postgres(_inner) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
"A database error occured",
),
Error::InvalidData => (
warp::http::StatusCode::BAD_REQUEST,
"Unable to operate on provided data",
),
Error::ApiKey => (
warp::http::StatusCode::UNAUTHORIZED,
"Invalid API key provided",
),
Error::RateLimit => (
warp::http::StatusCode::TOO_MANY_REQUESTS,
"Your API token is rate limited",
),
}
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
} else {
(
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
@ -381,5 +555,5 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert
message: message.into(),
});
Ok(warp::reply::with_status(json, code))
Ok(Box::new(warp::reply::with_status(json, code)))
}

View File

@ -1,8 +1,8 @@
#![recursion_limit = "256"]
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
use warp::Filter;
mod filters;
mod handlers;
@ -10,7 +10,65 @@ mod models;
mod types;
mod utils;
use warp::Filter;
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
type Pool = sqlx::PgPool;
#[derive(Debug)]
pub struct Node {
id: i32,
hash: [u8; 8],
}
impl Node {
pub fn query(hash: [u8; 8]) -> Self {
Self { id: -1, hash }
}
}
pub struct Hamming;
impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u64 {
hamming::distance_fast(&a.hash, &b.hash).unwrap()
}
}
#[tokio::main]
async fn main() {
configure_tracing();
ffmpeg_next::init().expect("Unable to initialize ffmpeg");
let s = std::env::var("DATABASE_URL").expect("Missing DATABASE_URL");
let db_pool = sqlx::PgPool::connect(&s)
.await
.expect("Unable to create Postgres pool");
serve_metrics().await;
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
load_updates(db_pool.clone(), tree.clone()).await;
let log = warp::log("fuzzysearch");
let cors = warp::cors()
.allow_any_origin()
.allow_headers(vec!["x-api-key"])
.allow_methods(vec!["GET", "POST"]);
let options = warp::options().map(|| "");
let api = options.or(filters::search(db_pool, tree));
let routes = api
.or(warp::path::end()
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
.with(log)
.with(cors)
.recover(handlers::handle_rejection);
warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
}
fn configure_tracing() {
use opentelemetry::{
@ -65,133 +123,102 @@ fn configure_tracing() {
registry.init();
}
#[derive(Debug)]
pub struct Node {
async fn metrics(
_: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
use hyper::{Body, Response};
use prometheus::{Encoder, TextEncoder};
let mut buffer = Vec::new();
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
encoder.encode(&metric_families, &mut buffer).unwrap();
Ok(Response::new(Body::from(buffer)))
}
async fn serve_metrics() {
use hyper::{
service::{make_service_fn, service_fn},
Server,
};
use std::convert::Infallible;
use std::net::SocketAddr;
let make_svc = make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(metrics)) });
let addr: SocketAddr = std::env::var("METRICS_HOST")
.expect("Missing METRICS_HOST")
.parse()
.expect("Invalid METRICS_HOST");
let server = Server::bind(&addr).serve(make_svc);
tokio::spawn(async move {
server.await.expect("Metrics server error");
});
}
#[derive(serde::Deserialize)]
struct HashRow {
id: i32,
hash: [u8; 8],
hash: i64,
}
impl Node {
pub fn query(hash: [u8; 8]) -> Self {
Self { id: -1, hash }
async fn create_tree(conn: &Pool) -> bk_tree::BKTree<Node, Hamming> {
use futures::TryStreamExt;
let mut tree = bk_tree::BKTree::new(Hamming);
let mut rows = sqlx::query_as!(HashRow, "SELECT id, hash FROM hashes").fetch(conn);
while let Some(row) = rows.try_next().await.expect("Unable to get row") {
tree.add(Node {
id: row.id,
hash: row.hash.to_be_bytes(),
})
}
tree
}
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
pub struct Hamming;
impl bk_tree::Metric<Node> for Hamming {
fn distance(&self, a: &Node, b: &Node) -> u64 {
hamming::distance_fast(&a.hash, &b.hash).unwrap()
}
}
#[tokio::main]
async fn main() {
ffmpeg_next::init().expect("Unable to initialize ffmpeg");
configure_tracing();
let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN");
let manager = bb8_postgres::PostgresConnectionManager::new(
tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"),
tokio_postgres::NoTls,
);
let db_pool = bb8::Pool::builder()
.build(manager)
async fn load_updates(conn: Pool, tree: Tree) {
let mut listener = sqlx::postgres::PgListener::connect_with(&conn)
.await
.expect("Unable to build Postgres pool");
.unwrap();
listener.listen("fuzzysearch_hash_added").await.unwrap();
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
let mut max_id = 0;
let conn = db_pool.get().await.unwrap();
let new_tree = create_tree(&conn).await;
let mut lock = tree.write().await;
conn.query("SELECT id, hash FROM hashes", &[])
*lock = new_tree;
drop(lock);
tokio::spawn(async move {
loop {
while let Some(notification) = listener
.try_recv()
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
.expect("Unable to recv notification")
{
let payload: HashRow = serde_json::from_str(notification.payload()).unwrap();
tracing::debug!(id = payload.id, "Adding new hash to tree");
if id > max_id {
max_id = id;
}
lock.add(Node { id, hash: bytes });
let mut lock = tree.write().await;
lock.add(Node {
id: payload.id,
hash: payload.hash.to_be_bytes(),
});
drop(lock);
drop(conn);
let tree_clone = tree.clone();
let pool_clone = db_pool.clone();
tokio::spawn(async move {
use futures_util::StreamExt;
let max_id = std::sync::atomic::AtomicI32::new(max_id);
let tree = tree_clone;
let pool = pool_clone;
let order = std::sync::atomic::Ordering::SeqCst;
let interval = tokio::time::interval(std::time::Duration::from_secs(30));
interval
.for_each(|_| async {
tracing::debug!("Refreshing hashes");
let conn = pool.get().await.unwrap();
let mut lock = tree.write().await;
let id = max_id.load(order);
let mut count = 0;
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
.await
.unwrap()
.into_iter()
.for_each(|row| {
let id: i32 = row.get(0);
let hash: i64 = row.get(1);
let bytes = hash.to_be_bytes();
if id > max_id.load(order) {
max_id.store(id, order);
}
lock.add(Node { id, hash: bytes });
count += 1;
tracing::error!("Lost connection to Postgres, recreating tree");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
let new_tree = create_tree(&conn).await;
let mut lock = tree.write().await;
*lock = new_tree;
drop(lock);
tracing::info!("Replaced tree");
}
});
tracing::trace!("Added {} new hashes", count);
})
.await;
});
let log = warp::log("fuzzysearch");
let cors = warp::cors()
.allow_any_origin()
.allow_headers(vec!["x-api-key"])
.allow_methods(vec!["GET", "POST"]);
let options = warp::options().map(|| "");
let api = options.or(filters::search(db_pool, tree));
let routes = api
.or(warp::path::end()
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
.with(log)
.with(cors)
.recover(handlers::handle_rejection);
warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
}
type Pool = bb8::Pool<bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;

View File

@ -1,46 +1,48 @@
use crate::types::*;
use crate::utils::extract_rows;
use crate::{Pool, Tree};
use lazy_static::lazy_static;
use prometheus::{register_histogram, Histogram};
use tracing_futures::Instrument;
use fuzzysearch_common::types::SearchResult;
use crate::types::*;
use crate::{Pool, Tree};
use fuzzysearch_common::types::{SearchResult, SiteInfo};
pub type DB<'a> =
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
lazy_static! {
static ref IMAGE_LOOKUP_DURATION: Histogram = register_histogram!(
"fuzzysearch_api_image_lookup_seconds",
"Duration to perform an image lookup"
)
.unwrap();
static ref IMAGE_QUERY_DURATION: Histogram = register_histogram!(
"fuzzysearch_api_image_query_seconds",
"Duration to perform a single image lookup query"
)
.unwrap();
}
#[tracing::instrument(skip(db))]
pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
let rows = db
.query(
pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option<ApiKey> {
sqlx::query_as!(
ApiKey,
"SELECT
api_key.id,
api_key.name_limit,
api_key.image_limit,
api_key.hash_limit,
api_key.name,
account.email
account.email owner_email
FROM
api_key
JOIN account
ON account.id = api_key.user_id
WHERE
api_key.key = $1",
&[&key],
api_key.key = $1
",
key
)
.fetch_optional(db)
.await
.expect("Unable to query API keys");
match rows.into_iter().next() {
Some(row) => Some(ApiKey {
id: row.get(0),
name_limit: row.get(1),
image_limit: row.get(2),
hash_limit: row.get(3),
name: row.get(4),
owner_email: row.get(5),
}),
_ => None,
}
.ok()
.flatten()
}
#[tracing::instrument(skip(pool, tree))]
@ -50,7 +52,7 @@ pub async fn image_query(
hashes: Vec<i64>,
distance: i64,
hash: Option<Vec<u8>>,
) -> Result<Vec<SearchResult>, tokio_postgres::Error> {
) -> Result<Vec<SearchResult>, sqlx::Error> {
let mut results = image_query_sync(pool, tree, hashes, distance, hash);
let mut matches = Vec::new();
@ -68,19 +70,30 @@ pub fn image_query_sync(
hashes: Vec<i64>,
distance: i64,
hash: Option<Vec<u8>>,
) -> tokio::sync::mpsc::Receiver<Result<Vec<SearchResult>, tokio_postgres::Error>> {
let (mut tx, rx) = tokio::sync::mpsc::channel(50);
) -> tokio::sync::mpsc::Receiver<Result<Vec<SearchResult>, sqlx::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel(50);
tokio::spawn(async move {
let db = pool.get().await.unwrap();
let db = pool;
for query_hash in hashes {
let mut seen = std::collections::HashSet::new();
let _timer = IMAGE_LOOKUP_DURATION.start_timer();
let node = crate::Node::query(query_hash.to_be_bytes());
let lock = tree.read().await;
let items = lock.find(&node, distance as u64);
for (_dist, item) in items {
let query = db.query("SELECT
for (dist, item) in items {
if seen.contains(&item.id) {
continue;
}
seen.insert(item.id);
let _timer = IMAGE_QUERY_DURATION.start_timer();
let row = sqlx::query!("SELECT
hashes.id,
hashes.hash,
hashes.furaffinity_id,
@ -106,7 +119,16 @@ pub fn image_query_sync(
END file_id,
CASE
WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))
END sources
END sources,
CASE
WHEN furaffinity_id IS NOT NULL THEN (f.rating)
WHEN e621_id IS NOT NULL THEN (e.data->>'rating')
WHEN twitter_id IS NOT NULL THEN
CASE
WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult'
WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general'
END
END rating
FROM
hashes
LEFT JOIN LATERAL (
@ -133,14 +155,45 @@ pub fn image_query_sync(
tweet_media.hash <@ (hashes.hash, 0)
LIMIT 1
) tm ON hashes.twitter_id IS NOT NULL
WHERE hashes.id = $1", &[&item.id]).await;
let rows = query.map(|rows| {
extract_rows(rows, hash.as_deref()).into_iter().map(|mut file| {
file.searched_hash = Some(query_hash);
file
}).collect()
});
tx.send(rows).await.unwrap();
WHERE hashes.id = $1", item.id).map(|row| {
let (site_id, site_info) = if let Some(fa_id) = row.furaffinity_id {
(
fa_id as i64,
Some(SiteInfo::FurAffinity {
file_id: row.file_id.unwrap(),
})
)
} else if let Some(e621_id) = row.e621_id {
(
e621_id as i64,
Some(SiteInfo::E621 {
sources: row.sources,
})
)
} else if let Some(twitter_id) = row.twitter_id {
(twitter_id, Some(SiteInfo::Twitter))
} else {
(-1, None)
};
let file = SearchResult {
id: row.id,
site_id,
site_info,
rating: row.rating.and_then(|rating| rating.parse().ok()),
site_id_str: site_id.to_string(),
url: row.url.unwrap_or_default(),
hash: Some(row.hash),
distance: Some(dist),
artists: row.artists,
filename: row.filename.unwrap_or_default(),
searched_hash: Some(query_hash),
};
vec![file]
}).fetch_one(&db).await;
tx.send(row).await.unwrap();
}
}
}.in_current_span());

View File

@ -10,7 +10,7 @@ use fuzzysearch_common::types::SearchResult;
pub struct ApiKey {
pub id: i32,
pub name: Option<String>,
pub owner_email: Option<String>,
pub owner_email: String,
pub name_limit: i16,
pub image_limit: i16,
pub hash_limit: i16,
@ -22,7 +22,7 @@ pub enum RateLimit {
/// This key is limited, we should deny the request.
Limited,
/// This key is available, contains the number of requests made.
Available(i16),
Available((i16, i16)),
}
#[derive(Debug, Deserialize)]
@ -68,3 +68,8 @@ pub struct HashSearchOpts {
pub struct HandleOpts {
pub twitter: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UrlSearchOpts {
pub url: String,
}

View File

@ -1,7 +1,15 @@
use crate::models::DB;
use crate::types::*;
use lazy_static::lazy_static;
use prometheus::{register_int_counter_vec, IntCounterVec};
use fuzzysearch_common::types::{SearchResult, SiteInfo};
lazy_static! {
pub static ref RATE_LIMIT_STATUS: IntCounterVec = register_int_counter_vec!(
"fuzzysearch_api_rate_limit_count",
"Number of allowed and rate limited requests",
&["status"]
)
.unwrap();
}
#[macro_export]
macro_rules! rate_limit {
@ -9,18 +17,48 @@ macro_rules! rate_limit {
rate_limit!($api_key, $db, $limit, $group, 1)
};
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {
let api_key = crate::models::lookup_api_key($api_key, $db)
.await
.ok_or_else(|| warp::reject::custom(Error::ApiKey))?;
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {{
let api_key = match crate::models::lookup_api_key($api_key, $db).await {
Some(api_key) => api_key,
None => return Ok(Box::new(Error::ApiKey)),
};
let rate_limit =
crate::utils::update_rate_limit($db, api_key.id, api_key.$limit, $group, $incr_by)
let rate_limit = match crate::utils::update_rate_limit(
$db,
api_key.id,
api_key.$limit,
$group,
$incr_by,
)
.await
.map_err(crate::handlers::map_postgres_err)?;
{
Ok(rate_limit) => rate_limit,
Err(err) => return Ok(Box::new(Error::Postgres(err))),
};
if rate_limit == crate::types::RateLimit::Limited {
return Err(warp::reject::custom(Error::RateLimit));
match rate_limit {
crate::types::RateLimit::Limited => {
crate::utils::RATE_LIMIT_STATUS
.with_label_values(&["limited"])
.inc();
return Ok(Box::new(Error::RateLimit));
}
crate::types::RateLimit::Available(count) => {
crate::utils::RATE_LIMIT_STATUS
.with_label_values(&["allowed"])
.inc();
count
}
}
}};
}
#[macro_export]
macro_rules! early_return {
($val:expr) => {
match $val {
Ok(val) => val,
Err(err) => return Ok(Box::new(Error::from(err))),
}
};
}
@ -33,18 +71,17 @@ macro_rules! rate_limit {
/// joined requests.
#[tracing::instrument(skip(db))]
pub async fn update_rate_limit(
db: DB<'_>,
db: &sqlx::PgPool,
key_id: i32,
key_group_limit: i16,
group_name: &'static str,
incr_by: i16,
) -> Result<RateLimit, tokio_postgres::Error> {
) -> Result<RateLimit, sqlx::Error> {
let now = chrono::Utc::now();
let timestamp = now.timestamp();
let time_window = timestamp - (timestamp % 60);
let rows = db
.query(
let count: i16 = sqlx::query_scalar!(
"INSERT INTO
rate_limit (api_key_id, time_window, group_name, count)
VALUES
@ -52,66 +89,20 @@ pub async fn update_rate_limit(
ON CONFLICT ON CONSTRAINT unique_window
DO UPDATE set count = rate_limit.count + $4
RETURNING rate_limit.count",
&[&key_id, &time_window, &group_name, &incr_by],
key_id,
time_window,
group_name,
incr_by
)
.fetch_one(db)
.await?;
let count: i16 = rows[0].get(0);
if count > key_group_limit {
Ok(RateLimit::Limited)
} else {
Ok(RateLimit::Available(count))
Ok(RateLimit::Available((
key_group_limit - count,
key_group_limit,
)))
}
}
pub fn extract_rows<'a>(
rows: Vec<tokio_postgres::Row>,
hash: Option<&'a [u8]>,
) -> impl IntoIterator<Item = SearchResult> + 'a {
rows.into_iter().map(move |row| {
let dbhash: i64 = row.get("hash");
let dbbytes = dbhash.to_be_bytes();
let (furaffinity_id, e621_id, twitter_id): (Option<i32>, Option<i32>, Option<i64>) = (
row.get("furaffinity_id"),
row.get("e621_id"),
row.get("twitter_id"),
);
let (site_id, site_info) = if let Some(fa_id) = furaffinity_id {
(
fa_id as i64,
Some(SiteInfo::FurAffinity {
file_id: row.get("file_id"),
}),
)
} else if let Some(e6_id) = e621_id {
(
e6_id as i64,
Some(SiteInfo::E621 {
sources: row.get("sources"),
}),
)
} else if let Some(t_id) = twitter_id {
(t_id, Some(SiteInfo::Twitter))
} else {
(-1, None)
};
SearchResult {
id: row.get("id"),
site_id,
site_info,
site_id_str: site_id.to_string(),
url: row.get("url"),
hash: Some(dbhash),
distance: hash
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
.flatten(),
artists: row.get("artists"),
filename: row.get("filename"),
searched_hash: None,
}
})
}