mirror of
https://github.com/Syfaro/fuzzysearch.git
synced 2024-11-05 14:32:56 +00:00
Merge branch 'master' into unify
This commit is contained in:
commit
8914227f23
@ -33,4 +33,8 @@ steps:
|
||||
from_secret: sccache_s3_endpoint
|
||||
SCCACHE_S3_USE_SSL: true
|
||||
|
||||
---
|
||||
kind: signature
|
||||
hmac: 665dab5e07086669c4b215ed86faa0e1e63c495b0bf020099fb1edd33757618b
|
||||
|
||||
...
|
||||
|
125
Cargo.lock
generated
125
Cargo.lock
generated
@ -61,6 +61,27 @@ version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3670df70cbc01729f901f94c887814b3c68db038aad1329a418bae178bc5295c"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3548b8efc9f8e8a5a0a2808c5bd8451a9031b9e5b879a79590304ae928b0a70"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.42"
|
||||
@ -105,30 +126,6 @@ version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
|
||||
|
||||
[[package]]
|
||||
name = "bb8"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374bba43fc924d90393ee7768e6f75d223a98307a488fe5bc34b66c3e96932a6"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"futures",
|
||||
"tokio 0.2.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bb8-postgres"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39a233af6ea3952e20d01863c87b4f6689b2f806249688b0908b5f02d4fa41ac"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bb8",
|
||||
"futures",
|
||||
"tokio 0.2.25",
|
||||
"tokio-postgres 0.5.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.54.0"
|
||||
@ -835,25 +832,27 @@ dependencies = [
|
||||
name = "fuzzysearch"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bb8",
|
||||
"bb8-postgres",
|
||||
"async-stream",
|
||||
"bk-tree",
|
||||
"bytes 0.5.6",
|
||||
"bytes 1.0.1",
|
||||
"chrono",
|
||||
"ffmpeg-next",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"fuzzysearch-common",
|
||||
"hamming",
|
||||
"hyper 0.14.4",
|
||||
"image",
|
||||
"img_hash",
|
||||
"infer",
|
||||
"lazy_static",
|
||||
"opentelemetry",
|
||||
"opentelemetry-jaeger",
|
||||
"prometheus 0.11.0",
|
||||
"reqwest 0.11.1",
|
||||
"serde",
|
||||
"tokio 0.2.25",
|
||||
"tokio-postgres 0.5.5",
|
||||
"serde_json",
|
||||
"sqlx 0.5.1",
|
||||
"tokio 1.2.0",
|
||||
"tracing",
|
||||
"tracing-futures",
|
||||
"tracing-opentelemetry",
|
||||
@ -1338,11 +1337,11 @@ checksum = "8906512588cd815b8f759fd0ac11de2a84c985c0f792f70df611e9325c270c1f"
|
||||
|
||||
[[package]]
|
||||
name = "input_buffer"
|
||||
version = "0.3.1"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19a8a95243d5a0398cae618ec29477c6e3cb631152be5c19481f80bc71559754"
|
||||
checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
|
||||
dependencies = [
|
||||
"bytes 0.5.6",
|
||||
"bytes 1.0.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2530,7 +2529,7 @@ dependencies = [
|
||||
"pin-project-lite 0.2.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded 0.7.0",
|
||||
"serde_urlencoded",
|
||||
"tokio 0.2.25",
|
||||
"tokio-tls",
|
||||
"url",
|
||||
@ -2565,7 +2564,7 @@ dependencies = [
|
||||
"pin-project-lite 0.2.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded 0.7.0",
|
||||
"serde_urlencoded",
|
||||
"tokio 1.2.0",
|
||||
"tokio-native-tls",
|
||||
"url",
|
||||
@ -2824,18 +2823,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ec5d77e2d4c73717816afac02670d5c4f534ea95ed430442cad02e7a6e32c97"
|
||||
dependencies = [
|
||||
"dtoa",
|
||||
"itoa",
|
||||
"serde",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.0"
|
||||
@ -3516,14 +3503,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.11.0"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d9e878ad426ca286e4dcae09cbd4e1973a7f8987d97570e2469703dd7f5720c"
|
||||
checksum = "e1a5f475f1b9d077ea1017ecbc60890fda8e54942d680ca0b1d2b47cfa2d861b"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"pin-project 0.4.27",
|
||||
"tokio 0.2.25",
|
||||
"pin-project 1.0.5",
|
||||
"tokio 1.2.0",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
@ -3699,18 +3686,18 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.11.1"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0308d80d86700c5878b9ef6321f020f29b1bb9d5ff3cab25e75e23f3a492a23"
|
||||
checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24"
|
||||
dependencies = [
|
||||
"base64 0.12.3",
|
||||
"base64 0.13.0",
|
||||
"byteorder",
|
||||
"bytes 0.5.6",
|
||||
"bytes 1.0.1",
|
||||
"http",
|
||||
"httparse",
|
||||
"input_buffer",
|
||||
"log",
|
||||
"rand 0.7.3",
|
||||
"rand 0.8.3",
|
||||
"sha-1 0.9.4",
|
||||
"url",
|
||||
"utf-8",
|
||||
@ -3812,12 +3799,6 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urlencoding"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9232eb53352b4442e40d7900465dfc534e8cb2dc8f18656fcb2ac16112b5593"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.5"
|
||||
@ -3848,30 +3829,32 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warp"
|
||||
version = "0.2.5"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f41be6df54c97904af01aa23e613d4521eed7ab23537cede692d4058f6449407"
|
||||
checksum = "3dafd0aac2818a94a34df0df1100a7356c493d8ede4393875fd0b5c51bb6bc80"
|
||||
dependencies = [
|
||||
"bytes 0.5.6",
|
||||
"bytes 1.0.1",
|
||||
"futures",
|
||||
"headers",
|
||||
"http",
|
||||
"hyper 0.13.10",
|
||||
"hyper 0.14.4",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"multipart",
|
||||
"pin-project 0.4.27",
|
||||
"percent-encoding",
|
||||
"pin-project 1.0.5",
|
||||
"scoped-tls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded 0.6.1",
|
||||
"tokio 0.2.25",
|
||||
"serde_urlencoded",
|
||||
"tokio 1.2.0",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util 0.6.3",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"tracing-futures",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1,5 +1,28 @@
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Rating {
|
||||
General,
|
||||
Mature,
|
||||
Adult,
|
||||
}
|
||||
|
||||
impl std::str::FromStr for Rating {
|
||||
type Err = &'static str;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let rating = match s {
|
||||
"g" | "s" | "general" => Self::General,
|
||||
"m" | "q" | "mature" => Self::Mature,
|
||||
"a" | "e" | "adult" => Self::Adult,
|
||||
_ => return Err("unknown rating"),
|
||||
};
|
||||
|
||||
Ok(rating)
|
||||
}
|
||||
}
|
||||
|
||||
/// A general type for every result in a search.
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct SearchResult {
|
||||
@ -11,6 +34,7 @@ pub struct SearchResult {
|
||||
pub url: String,
|
||||
pub filename: String,
|
||||
pub artists: Option<Vec<String>>,
|
||||
pub rating: Option<Rating>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(flatten)]
|
||||
|
@ -165,7 +165,7 @@ async fn main() {
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
panic!(e);
|
||||
panic!("postgres connection error: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -9,29 +9,33 @@ tracing = "0.1"
|
||||
tracing-subscriber = "0.2"
|
||||
tracing-futures = "0.2"
|
||||
|
||||
prometheus = { version = "0.11", features = ["process"] }
|
||||
lazy_static = "1"
|
||||
|
||||
opentelemetry = "0.6"
|
||||
opentelemetry-jaeger = "0.5"
|
||||
tracing-opentelemetry = "0.5"
|
||||
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
async-stream = "0.3"
|
||||
|
||||
futures = "0.3"
|
||||
|
||||
anyhow = "1"
|
||||
chrono = "0.4"
|
||||
bytes = "0.5"
|
||||
infer = { version = "0.3", default-features = false }
|
||||
bytes = "1"
|
||||
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
warp = "0.2"
|
||||
serde_json = "1"
|
||||
|
||||
tokio-postgres = "0.5"
|
||||
bb8 = "0.4"
|
||||
bb8-postgres = "0.4"
|
||||
warp = "0.3"
|
||||
reqwest = "0.11"
|
||||
hyper = "0.14"
|
||||
|
||||
image = "0.23"
|
||||
sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "postgres", "macros", "json", "offline"] }
|
||||
|
||||
infer = { version = "0.3", default-features = false }
|
||||
ffmpeg-next = "4"
|
||||
|
||||
image = "0.23"
|
||||
img_hash = "3"
|
||||
hamming = "0.1"
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
FROM rust:1-slim AS builder
|
||||
WORKDIR /src
|
||||
ENV SQLX_OFFLINE=true
|
||||
RUN apt-get update -y && apt-get install -y libssl-dev pkg-config
|
||||
COPY . .
|
||||
RUN cargo install --root / --path .
|
||||
|
||||
FROM debian:buster-slim
|
||||
EXPOSE 8080
|
||||
EXPOSE 8080 8081
|
||||
ENV METRICS_HOST=0.0.0.0:8081
|
||||
WORKDIR /app
|
||||
RUN apt-get update -y && apt-get install -y openssl ca-certificates && rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=builder /bin/fuzzysearch /bin/fuzzysearch
|
||||
CMD ["/bin/fuzzysearch"]
|
||||
|
200
fuzzysearch/sqlx-data.json
Normal file
200
fuzzysearch/sqlx-data.json
Normal file
@ -0,0 +1,200 @@
|
||||
{
|
||||
"db": "PostgreSQL",
|
||||
"1984ce60f052d6a29638f8e05b35671b8edfbf273783d4b843ebd35cbb8a391f": {
|
||||
"query": "INSERT INTO\n rate_limit (api_key_id, time_window, group_name, count)\n VALUES\n ($1, $2, $3, $4)\n ON CONFLICT ON CONSTRAINT unique_window\n DO UPDATE set count = rate_limit.count + $4\n RETURNING rate_limit.count",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "count",
|
||||
"type_info": "Int2"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int4",
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int2"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"1bd0057782de5a3b41f90081a31d24d14bb70299391050c3404742a6d2915d9e": {
|
||||
"query": "SELECT\n hashes.id,\n hashes.hash,\n hashes.furaffinity_id,\n hashes.e621_id,\n hashes.twitter_id,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.url)\n WHEN e621_id IS NOT NULL THEN (e.data->'file'->>'url')\n WHEN twitter_id IS NOT NULL THEN (tm.url)\n END url,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.filename)\n WHEN e621_id IS NOT NULL THEN ((e.data->'file'->>'md5') || '.' || (e.data->'file'->>'ext'))\n WHEN twitter_id IS NOT NULL THEN (SELECT split_part(split_part(tm.url, '/', 5), ':', 1))\n END filename,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (ARRAY(SELECT f.name))\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'tags'->'artist'))\n WHEN twitter_id IS NOT NULL THEN ARRAY(SELECT tw.data->'user'->>'screen_name')\n END artists,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.file_id)\n END file_id,\n CASE\n WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))\n END sources,\n CASE\n WHEN furaffinity_id IS NOT NULL THEN (f.rating)\n WHEN e621_id IS NOT NULL THEN (e.data->>'rating')\n WHEN twitter_id IS NOT NULL THEN\n CASE\n WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult'\n WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general'\n END\n END rating\n FROM\n hashes\n LEFT JOIN LATERAL (\n SELECT *\n FROM submission\n JOIN artist ON submission.artist_id = artist.id\n WHERE submission.id = hashes.furaffinity_id\n ) f ON hashes.furaffinity_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM e621\n WHERE e621.id = hashes.e621_id\n ) e ON hashes.e621_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet\n WHERE tweet.id = hashes.twitter_id\n ) tw ON hashes.twitter_id IS NOT NULL\n LEFT JOIN LATERAL (\n SELECT *\n FROM tweet_media\n WHERE\n tweet_media.tweet_id = hashes.twitter_id AND\n tweet_media.hash <@ (hashes.hash, 0)\n LIMIT 1\n ) tm ON hashes.twitter_id IS NOT NULL\n WHERE hashes.id = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "hash",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "furaffinity_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "e621_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "twitter_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "url",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "filename",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "artists",
|
||||
"type_info": "TextArray"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "file_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "sources",
|
||||
"type_info": "TextArray"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "rating",
|
||||
"type_info": "Bpchar"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int4"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
}
|
||||
},
|
||||
"659ee9ddc1c5ccd42ba9dc1617440544c30ece449ba3ba7f9d39f447b8af3cfe": {
|
||||
"query": "SELECT\n api_key.id,\n api_key.name_limit,\n api_key.image_limit,\n api_key.hash_limit,\n api_key.name,\n account.email owner_email\n FROM\n api_key\n JOIN account\n ON account.id = api_key.user_id\n WHERE\n api_key.key = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "name_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "image_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "hash_limit",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "name",
|
||||
"type_info": "Varchar"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "owner_email",
|
||||
"type_info": "Varchar"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"6b8d304fc40fa539ae671e6e24e7978ad271cb7a1cafb20fc4b4096a958d790f": {
|
||||
"query": "SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
}
|
||||
},
|
||||
"fe60be66b2d8a8f02b3bfe06d1f0e57e4bb07e80cba1b379a5f17f6cbd8b075c": {
|
||||
"query": "SELECT id, hash FROM hashes",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "hash",
|
||||
"type_info": "Int8"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
@ -10,10 +10,11 @@ pub fn search(
|
||||
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
search_image(db.clone(), tree.clone())
|
||||
.or(search_hashes(db.clone(), tree.clone()))
|
||||
.or(stream_search_image(db.clone(), tree))
|
||||
.or(stream_search_image(db.clone(), tree.clone()))
|
||||
.or(search_file(db.clone()))
|
||||
.or(search_video(db.clone()))
|
||||
.or(check_handle(db))
|
||||
.or(check_handle(db.clone()))
|
||||
.or(search_image_by_url(db, tree))
|
||||
}
|
||||
|
||||
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
@ -55,6 +56,19 @@ pub fn search_image(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn search_image_by_url(
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
warp::path("url")
|
||||
.and(warp::get())
|
||||
.and(warp::query::<UrlSearchOpts>())
|
||||
.and(with_pool(db))
|
||||
.and(with_tree(tree))
|
||||
.and(with_api_key())
|
||||
.and_then(handlers::search_image_by_url)
|
||||
}
|
||||
|
||||
pub fn search_hashes(
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
|
@ -1,46 +1,92 @@
|
||||
use crate::models::{image_query, image_query_sync};
|
||||
use crate::types::*;
|
||||
use crate::{rate_limit, Pool, Tree};
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{register_histogram, register_int_counter, Histogram, IntCounter};
|
||||
use std::convert::TryInto;
|
||||
use tracing::{span, warn};
|
||||
use tracing_futures::Instrument;
|
||||
use warp::{reject, Rejection, Reply};
|
||||
use warp::{Rejection, Reply};
|
||||
|
||||
use crate::models::{image_query, image_query_sync};
|
||||
use crate::types::*;
|
||||
use crate::{early_return, rate_limit, Pool, Tree};
|
||||
use fuzzysearch_common::types::{SearchResult, SiteInfo};
|
||||
|
||||
fn map_bb8_err(err: bb8::RunError<tokio_postgres::Error>) -> Rejection {
|
||||
reject::custom(Error::from(err))
|
||||
}
|
||||
|
||||
fn map_postgres_err(err: tokio_postgres::Error) -> Rejection {
|
||||
reject::custom(Error::from(err))
|
||||
lazy_static! {
|
||||
static ref IMAGE_HASH_DURATION: Histogram = register_histogram!(
|
||||
"fuzzysearch_api_image_hash_seconds",
|
||||
"Duration to perform an image hash operation"
|
||||
)
|
||||
.unwrap();
|
||||
static ref IMAGE_URL_DOWNLOAD_DURATION: Histogram = register_histogram!(
|
||||
"fuzzysearch_api_image_url_download_seconds",
|
||||
"Duration to download an image from a provided URL"
|
||||
)
|
||||
.unwrap();
|
||||
static ref UNHANDLED_REJECTIONS: IntCounter = register_int_counter!(
|
||||
"fuzzysearch_api_unhandled_rejections_count",
|
||||
"Number of unhandled HTTP rejections"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
BB8(bb8::RunError<tokio_postgres::Error>),
|
||||
Postgres(tokio_postgres::Error),
|
||||
Postgres(sqlx::Error),
|
||||
Reqwest(reqwest::Error),
|
||||
InvalidData,
|
||||
InvalidImage,
|
||||
ApiKey,
|
||||
RateLimit,
|
||||
}
|
||||
|
||||
impl From<bb8::RunError<tokio_postgres::Error>> for Error {
|
||||
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self {
|
||||
Error::BB8(err)
|
||||
impl warp::Reply for Error {
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
let msg = match self {
|
||||
Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
|
||||
code: 500,
|
||||
message: "Internal server error".to_string(),
|
||||
},
|
||||
Error::InvalidData => ErrorMessage {
|
||||
code: 400,
|
||||
message: "Invalid data provided".to_string(),
|
||||
},
|
||||
Error::InvalidImage => ErrorMessage {
|
||||
code: 400,
|
||||
message: "Invalid image provided".to_string(),
|
||||
},
|
||||
Error::ApiKey => ErrorMessage {
|
||||
code: 401,
|
||||
message: "Invalid API key".to_string(),
|
||||
},
|
||||
Error::RateLimit => ErrorMessage {
|
||||
code: 429,
|
||||
message: "Too many requests".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let body = hyper::body::Body::from(serde_json::to_string(&msg).unwrap());
|
||||
|
||||
warp::http::Response::builder()
|
||||
.status(msg.code)
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for Error {
|
||||
fn from(err: tokio_postgres::Error) -> Self {
|
||||
impl From<sqlx::Error> for Error {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
Error::Postgres(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl warp::reject::Reject for Error {}
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
Error::Reqwest(err)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_field_bytes(form: warp::multipart::FormData, field: &str) -> bytes::BytesMut {
|
||||
use bytes::BufMut;
|
||||
use futures_util::StreamExt;
|
||||
use futures::StreamExt;
|
||||
|
||||
let parts: Vec<_> = form.collect().await;
|
||||
let mut parts = parts
|
||||
@ -66,6 +112,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
|
||||
|
||||
let len = bytes.len();
|
||||
|
||||
let _timer = IMAGE_HASH_DURATION.start_timer();
|
||||
let hash = tokio::task::spawn_blocking(move || {
|
||||
let hasher = fuzzysearch_common::get_hasher();
|
||||
let image = image::load_from_memory(&bytes).unwrap();
|
||||
@ -74,6 +121,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
|
||||
.instrument(span!(tracing::Level::TRACE, "hashing image", len))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(_timer);
|
||||
|
||||
let mut buf: [u8; 8] = [0; 8];
|
||||
buf.copy_from_slice(&hash.as_bytes());
|
||||
@ -83,7 +131,7 @@ async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHas
|
||||
|
||||
#[tracing::instrument(skip(form))]
|
||||
async fn hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> {
|
||||
use bytes::buf::BufExt;
|
||||
use bytes::Buf;
|
||||
|
||||
let bytes = get_field_bytes(form, "video").await;
|
||||
|
||||
@ -106,21 +154,19 @@ async fn hash_video(form: warp::multipart::FormData) -> Vec<[u8; 8]> {
|
||||
pub async fn search_image(
|
||||
form: warp::multipart::FormData,
|
||||
opts: ImageSearchOpts,
|
||||
pool: Pool,
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let db = pool.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image");
|
||||
rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
|
||||
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
|
||||
let (num, hash) = hash_input(form).await;
|
||||
|
||||
let mut items = {
|
||||
if opts.search_type == Some(ImageSearchType::Force) {
|
||||
image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
10,
|
||||
@ -130,7 +176,7 @@ pub async fn search_image(
|
||||
.unwrap()
|
||||
} else {
|
||||
let results = image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
0,
|
||||
@ -140,7 +186,7 @@ pub async fn search_image(
|
||||
.unwrap();
|
||||
if results.is_empty() && opts.search_type != Some(ImageSearchType::Exact) {
|
||||
image_query(
|
||||
pool.clone(),
|
||||
db.clone(),
|
||||
tree.clone(),
|
||||
vec![num],
|
||||
10,
|
||||
@ -166,7 +212,20 @@ pub async fn search_image(
|
||||
matches: items,
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&similarity))
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-image-hash", num.to_string())
|
||||
.header("x-rate-limit-total-image", image_remaining.1.to_string())
|
||||
.header(
|
||||
"x-rate-limit-remaining-image",
|
||||
image_remaining.0.to_string(),
|
||||
)
|
||||
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
|
||||
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
|
||||
.header("content-type", "application/json")
|
||||
.body(serde_json::to_string(&similarity).unwrap())
|
||||
.unwrap();
|
||||
|
||||
Ok(Box::new(resp))
|
||||
}
|
||||
|
||||
pub async fn stream_image(
|
||||
@ -174,34 +233,36 @@ pub async fn stream_image(
|
||||
pool: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
let db = pool.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image", 2);
|
||||
rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
rate_limit!(&api_key, &pool, image_limit, "image", 2);
|
||||
rate_limit!(&api_key, &pool, hash_limit, "hash");
|
||||
|
||||
let (num, hash) = hash_input(form).await;
|
||||
|
||||
let event_stream = image_query_sync(
|
||||
let mut query = image_query_sync(
|
||||
pool.clone(),
|
||||
tree,
|
||||
vec![num],
|
||||
10,
|
||||
Some(hash.as_bytes().to_vec()),
|
||||
)
|
||||
.map(sse_matches);
|
||||
);
|
||||
|
||||
Ok(warp::sse::reply(event_stream))
|
||||
let event_stream = async_stream::stream! {
|
||||
while let Some(result) = query.recv().await {
|
||||
yield sse_matches(result);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::new(warp::sse::reply(event_stream)))
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn sse_matches(
|
||||
matches: Result<Vec<SearchResult>, tokio_postgres::Error>,
|
||||
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
|
||||
matches: Result<Vec<SearchResult>, sqlx::Error>,
|
||||
) -> Result<warp::sse::Event, core::convert::Infallible> {
|
||||
let items = matches.unwrap();
|
||||
|
||||
Ok(warp::sse::json(items))
|
||||
Ok(warp::sse::Event::default().json_data(items).unwrap())
|
||||
}
|
||||
|
||||
pub async fn search_hashes(
|
||||
@ -209,9 +270,8 @@ pub async fn search_hashes(
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let pool = db.clone();
|
||||
let db = db.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
let hashes: Vec<i64> = opts
|
||||
.hashes
|
||||
@ -221,10 +281,10 @@ pub async fn search_hashes(
|
||||
.collect();
|
||||
|
||||
if hashes.is_empty() {
|
||||
return Err(warp::reject::custom(Error::InvalidData));
|
||||
return Ok(Box::new(Error::InvalidData));
|
||||
}
|
||||
|
||||
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
|
||||
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
|
||||
|
||||
let mut results = image_query_sync(
|
||||
pool,
|
||||
@ -236,66 +296,107 @@ pub async fn search_hashes(
|
||||
let mut matches = Vec::new();
|
||||
|
||||
while let Some(r) = results.recv().await {
|
||||
matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
|
||||
matches.extend(early_return!(r));
|
||||
}
|
||||
|
||||
Ok(warp::reply::json(&matches))
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-rate-limit-total-image", image_remaining.1.to_string())
|
||||
.header(
|
||||
"x-rate-limit-remaining-image",
|
||||
image_remaining.0.to_string(),
|
||||
)
|
||||
.header("content-type", "application/json")
|
||||
.body(serde_json::to_string(&matches).unwrap())
|
||||
.unwrap();
|
||||
|
||||
Ok(Box::new(resp))
|
||||
}
|
||||
|
||||
pub async fn search_file(
|
||||
opts: FileSearchOpts,
|
||||
db: Pool,
|
||||
api_key: String,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let db = db.get().await.map_err(map_bb8_err)?;
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
use sqlx::Row;
|
||||
|
||||
rate_limit!(&api_key, &db, name_limit, "file");
|
||||
let file_remaining = rate_limit!(&api_key, &db, name_limit, "file");
|
||||
|
||||
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) =
|
||||
if let Some(ref id) = opts.id {
|
||||
("file_id = $1", id)
|
||||
} else if let Some(ref name) = opts.name {
|
||||
("lower(filename) = lower($1)", name)
|
||||
} else if let Some(ref url) = opts.url {
|
||||
("lower(url) = lower($1)", url)
|
||||
} else {
|
||||
return Err(warp::reject::custom(Error::InvalidData));
|
||||
};
|
||||
let query = if let Some(ref id) = opts.id {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
submission.rating,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
file_id = $1
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(id)
|
||||
} else if let Some(ref name) = opts.name {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
submission.rating,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
lower(filename) = lower($1)
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(name)
|
||||
} else if let Some(ref url) = opts.url {
|
||||
sqlx::query(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
submission.rating,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
lower(url) = lower($1)
|
||||
LIMIT 10",
|
||||
)
|
||||
.bind(url)
|
||||
} else {
|
||||
return Ok(Box::new(Error::InvalidData));
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
"SELECT
|
||||
submission.id,
|
||||
submission.url,
|
||||
submission.filename,
|
||||
submission.file_id,
|
||||
artist.name,
|
||||
hashes.id hash_id
|
||||
FROM
|
||||
submission
|
||||
JOIN artist
|
||||
ON artist.id = submission.artist_id
|
||||
JOIN hashes
|
||||
ON hashes.furaffinity_id = submission.id
|
||||
WHERE
|
||||
{}
|
||||
LIMIT 10",
|
||||
filter
|
||||
);
|
||||
|
||||
let matches: Vec<_> = db
|
||||
.query::<str>(&*query, &[val])
|
||||
.instrument(span!(tracing::Level::TRACE, "waiting for db"))
|
||||
.await
|
||||
.map_err(map_postgres_err)?
|
||||
.into_iter()
|
||||
let matches: Result<Vec<SearchResult>, _> = query
|
||||
.map(|row| SearchResult {
|
||||
id: row.get("hash_id"),
|
||||
site_id: row.get::<&str, i32>("id") as i64,
|
||||
site_id_str: row.get::<&str, i32>("id").to_string(),
|
||||
site_id: row.get::<i32, _>("id") as i64,
|
||||
site_id_str: row.get::<i32, _>("id").to_string(),
|
||||
url: row.get("url"),
|
||||
filename: row.get("filename"),
|
||||
artists: row
|
||||
.get::<&str, Option<String>>("name")
|
||||
.get::<Option<String>, _>("name")
|
||||
.map(|artist| vec![artist]),
|
||||
distance: None,
|
||||
hash: None,
|
||||
@ -303,10 +404,23 @@ pub async fn search_file(
|
||||
file_id: row.get("file_id"),
|
||||
}),
|
||||
searched_hash: None,
|
||||
rating: row
|
||||
.get::<Option<String>, _>("rating")
|
||||
.and_then(|rating| rating.parse().ok()),
|
||||
})
|
||||
.collect();
|
||||
.fetch_all(&db)
|
||||
.await;
|
||||
|
||||
Ok(warp::reply::json(&matches))
|
||||
let matches = early_return!(matches);
|
||||
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-rate-limit-total-file", file_remaining.1.to_string())
|
||||
.header("x-rate-limit-remaining-file", file_remaining.0.to_string())
|
||||
.header("content-type", "application/json")
|
||||
.body(serde_json::to_string(&matches).unwrap())
|
||||
.unwrap();
|
||||
|
||||
Ok(Box::new(resp))
|
||||
}
|
||||
|
||||
pub async fn search_video(
|
||||
@ -319,56 +433,116 @@ pub async fn search_video(
|
||||
Ok(warp::reply::json(&hashes))
|
||||
}
|
||||
|
||||
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<impl Reply, Rejection> {
|
||||
let db = db.get().await.map_err(map_bb8_err)?;
|
||||
|
||||
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let exists = if let Some(handle) = opts.twitter {
|
||||
!db.query(
|
||||
"SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)",
|
||||
&[&handle],
|
||||
)
|
||||
.await
|
||||
.map_err(map_postgres_err)?
|
||||
.is_empty()
|
||||
let result = sqlx::query_scalar!("SELECT exists(SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1))", handle)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map(|row| row.flatten().unwrap_or(false));
|
||||
|
||||
early_return!(result)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&exists))
|
||||
Ok(Box::new(warp::reply::json(&exists)))
|
||||
}
|
||||
|
||||
pub async fn search_image_by_url(
|
||||
opts: UrlSearchOpts,
|
||||
db: Pool,
|
||||
tree: Tree,
|
||||
api_key: String,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
use bytes::BufMut;
|
||||
|
||||
let url = opts.url;
|
||||
|
||||
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
|
||||
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
|
||||
|
||||
let _timer = IMAGE_URL_DOWNLOAD_DURATION.start_timer();
|
||||
|
||||
let mut resp = match reqwest::get(&url).await {
|
||||
Ok(resp) => resp,
|
||||
Err(_err) => return Ok(Box::new(Error::InvalidImage)),
|
||||
};
|
||||
|
||||
let content_length = resp
|
||||
.headers()
|
||||
.get("content-length")
|
||||
.and_then(|len| {
|
||||
String::from_utf8_lossy(len.as_bytes())
|
||||
.parse::<usize>()
|
||||
.ok()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
if content_length > 10_000_000 {
|
||||
return Ok(Box::new(Error::InvalidImage));
|
||||
}
|
||||
|
||||
let mut buf = bytes::BytesMut::with_capacity(content_length);
|
||||
|
||||
while let Some(chunk) = early_return!(resp.chunk().await) {
|
||||
if buf.len() + chunk.len() > 10_000_000 {
|
||||
return Ok(Box::new(Error::InvalidImage));
|
||||
}
|
||||
|
||||
buf.put(chunk);
|
||||
}
|
||||
|
||||
drop(_timer);
|
||||
|
||||
let _timer = IMAGE_HASH_DURATION.start_timer();
|
||||
let hash = tokio::task::spawn_blocking(move || {
|
||||
let hasher = fuzzysearch_common::get_hasher();
|
||||
let image = image::load_from_memory(&buf).unwrap();
|
||||
hasher.hash_image(&image)
|
||||
})
|
||||
.instrument(span!(tracing::Level::TRACE, "hashing image"))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(_timer);
|
||||
|
||||
let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
|
||||
let num = i64::from_be_bytes(hash);
|
||||
|
||||
let results = image_query(db.clone(), tree.clone(), vec![num], 3, Some(hash.to_vec()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = warp::http::Response::builder()
|
||||
.header("x-image-hash", num.to_string())
|
||||
.header("x-rate-limit-total-image", image_remaining.1.to_string())
|
||||
.header(
|
||||
"x-rate-limit-remaining-image",
|
||||
image_remaining.0.to_string(),
|
||||
)
|
||||
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
|
||||
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
|
||||
.header("content-type", "application/json")
|
||||
.body(serde_json::to_string(&results).unwrap())
|
||||
.unwrap();
|
||||
|
||||
Ok(Box::new(resp))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
|
||||
pub async fn handle_rejection(err: Rejection) -> Result<Box<dyn Reply>, std::convert::Infallible> {
|
||||
warn!("had rejection");
|
||||
|
||||
UNHANDLED_REJECTIONS.inc();
|
||||
|
||||
let (code, message) = if err.is_not_found() {
|
||||
(
|
||||
warp::http::StatusCode::NOT_FOUND,
|
||||
"This page does not exist",
|
||||
)
|
||||
} else if let Some(err) = err.find::<Error>() {
|
||||
match err {
|
||||
Error::BB8(_inner) => (
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"A database error occured",
|
||||
),
|
||||
Error::Postgres(_inner) => (
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"A database error occured",
|
||||
),
|
||||
Error::InvalidData => (
|
||||
warp::http::StatusCode::BAD_REQUEST,
|
||||
"Unable to operate on provided data",
|
||||
),
|
||||
Error::ApiKey => (
|
||||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
"Invalid API key provided",
|
||||
),
|
||||
Error::RateLimit => (
|
||||
warp::http::StatusCode::TOO_MANY_REQUESTS,
|
||||
"Your API token is rate limited",
|
||||
),
|
||||
}
|
||||
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
|
||||
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
|
||||
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
|
||||
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
|
||||
} else {
|
||||
(
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@ -381,5 +555,5 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert
|
||||
message: message.into(),
|
||||
});
|
||||
|
||||
Ok(warp::reply::with_status(json, code))
|
||||
Ok(Box::new(warp::reply::with_status(json, code)))
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use warp::Filter;
|
||||
|
||||
mod filters;
|
||||
mod handlers;
|
||||
@ -10,7 +10,65 @@ mod models;
|
||||
mod types;
|
||||
mod utils;
|
||||
|
||||
use warp::Filter;
|
||||
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
||||
type Pool = sqlx::PgPool;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
id: i32,
|
||||
hash: [u8; 8],
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn query(hash: [u8; 8]) -> Self {
|
||||
Self { id: -1, hash }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Hamming;
|
||||
|
||||
impl bk_tree::Metric<Node> for Hamming {
|
||||
fn distance(&self, a: &Node, b: &Node) -> u64 {
|
||||
hamming::distance_fast(&a.hash, &b.hash).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
configure_tracing();
|
||||
|
||||
ffmpeg_next::init().expect("Unable to initialize ffmpeg");
|
||||
|
||||
let s = std::env::var("DATABASE_URL").expect("Missing DATABASE_URL");
|
||||
|
||||
let db_pool = sqlx::PgPool::connect(&s)
|
||||
.await
|
||||
.expect("Unable to create Postgres pool");
|
||||
|
||||
serve_metrics().await;
|
||||
|
||||
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
|
||||
|
||||
load_updates(db_pool.clone(), tree.clone()).await;
|
||||
|
||||
let log = warp::log("fuzzysearch");
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(vec!["x-api-key"])
|
||||
.allow_methods(vec!["GET", "POST"]);
|
||||
|
||||
let options = warp::options().map(|| "✓");
|
||||
|
||||
let api = options.or(filters::search(db_pool, tree));
|
||||
let routes = api
|
||||
.or(warp::path::end()
|
||||
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
|
||||
.with(log)
|
||||
.with(cors)
|
||||
.recover(handlers::handle_rejection);
|
||||
|
||||
warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
|
||||
}
|
||||
|
||||
fn configure_tracing() {
|
||||
use opentelemetry::{
|
||||
@ -65,133 +123,102 @@ fn configure_tracing() {
|
||||
registry.init();
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
async fn metrics(
|
||||
_: hyper::Request<hyper::Body>,
|
||||
) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
|
||||
use hyper::{Body, Response};
|
||||
use prometheus::{Encoder, TextEncoder};
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = TextEncoder::new();
|
||||
|
||||
let metric_families = prometheus::gather();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
|
||||
Ok(Response::new(Body::from(buffer)))
|
||||
}
|
||||
|
||||
async fn serve_metrics() {
|
||||
use hyper::{
|
||||
service::{make_service_fn, service_fn},
|
||||
Server,
|
||||
};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
let make_svc = make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(metrics)) });
|
||||
|
||||
let addr: SocketAddr = std::env::var("METRICS_HOST")
|
||||
.expect("Missing METRICS_HOST")
|
||||
.parse()
|
||||
.expect("Invalid METRICS_HOST");
|
||||
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
|
||||
tokio::spawn(async move {
|
||||
server.await.expect("Metrics server error");
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct HashRow {
|
||||
id: i32,
|
||||
hash: [u8; 8],
|
||||
hash: i64,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn query(hash: [u8; 8]) -> Self {
|
||||
Self { id: -1, hash }
|
||||
async fn create_tree(conn: &Pool) -> bk_tree::BKTree<Node, Hamming> {
|
||||
use futures::TryStreamExt;
|
||||
|
||||
let mut tree = bk_tree::BKTree::new(Hamming);
|
||||
|
||||
let mut rows = sqlx::query_as!(HashRow, "SELECT id, hash FROM hashes").fetch(conn);
|
||||
|
||||
while let Some(row) = rows.try_next().await.expect("Unable to get row") {
|
||||
tree.add(Node {
|
||||
id: row.id,
|
||||
hash: row.hash.to_be_bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
tree
|
||||
}
|
||||
|
||||
type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
|
||||
|
||||
pub struct Hamming;
|
||||
|
||||
impl bk_tree::Metric<Node> for Hamming {
|
||||
fn distance(&self, a: &Node, b: &Node) -> u64 {
|
||||
hamming::distance_fast(&a.hash, &b.hash).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
ffmpeg_next::init().expect("Unable to initialize ffmpeg");
|
||||
|
||||
configure_tracing();
|
||||
|
||||
let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN");
|
||||
|
||||
let manager = bb8_postgres::PostgresConnectionManager::new(
|
||||
tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"),
|
||||
tokio_postgres::NoTls,
|
||||
);
|
||||
|
||||
let db_pool = bb8::Pool::builder()
|
||||
.build(manager)
|
||||
async fn load_updates(conn: Pool, tree: Tree) {
|
||||
let mut listener = sqlx::postgres::PgListener::connect_with(&conn)
|
||||
.await
|
||||
.expect("Unable to build Postgres pool");
|
||||
.unwrap();
|
||||
listener.listen("fuzzysearch_hash_added").await.unwrap();
|
||||
|
||||
let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));
|
||||
|
||||
let mut max_id = 0;
|
||||
|
||||
let conn = db_pool.get().await.unwrap();
|
||||
let new_tree = create_tree(&conn).await;
|
||||
let mut lock = tree.write().await;
|
||||
conn.query("SELECT id, hash FROM hashes", &[])
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.for_each(|row| {
|
||||
let id: i32 = row.get(0);
|
||||
let hash: i64 = row.get(1);
|
||||
let bytes = hash.to_be_bytes();
|
||||
*lock = new_tree;
|
||||
drop(lock);
|
||||
|
||||
if id > max_id {
|
||||
max_id = id;
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
while let Some(notification) = listener
|
||||
.try_recv()
|
||||
.await
|
||||
.expect("Unable to recv notification")
|
||||
{
|
||||
let payload: HashRow = serde_json::from_str(notification.payload()).unwrap();
|
||||
tracing::debug!(id = payload.id, "Adding new hash to tree");
|
||||
|
||||
let mut lock = tree.write().await;
|
||||
lock.add(Node {
|
||||
id: payload.id,
|
||||
hash: payload.hash.to_be_bytes(),
|
||||
});
|
||||
drop(lock);
|
||||
}
|
||||
|
||||
lock.add(Node { id, hash: bytes });
|
||||
});
|
||||
drop(lock);
|
||||
drop(conn);
|
||||
|
||||
let tree_clone = tree.clone();
|
||||
let pool_clone = db_pool.clone();
|
||||
tokio::spawn(async move {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
let max_id = std::sync::atomic::AtomicI32::new(max_id);
|
||||
let tree = tree_clone;
|
||||
let pool = pool_clone;
|
||||
|
||||
let order = std::sync::atomic::Ordering::SeqCst;
|
||||
|
||||
let interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
|
||||
interval
|
||||
.for_each(|_| async {
|
||||
tracing::debug!("Refreshing hashes");
|
||||
|
||||
let conn = pool.get().await.unwrap();
|
||||
let mut lock = tree.write().await;
|
||||
let id = max_id.load(order);
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.for_each(|row| {
|
||||
let id: i32 = row.get(0);
|
||||
let hash: i64 = row.get(1);
|
||||
let bytes = hash.to_be_bytes();
|
||||
|
||||
if id > max_id.load(order) {
|
||||
max_id.store(id, order);
|
||||
}
|
||||
|
||||
lock.add(Node { id, hash: bytes });
|
||||
|
||||
count += 1;
|
||||
});
|
||||
|
||||
tracing::trace!("Added {} new hashes", count);
|
||||
})
|
||||
.await;
|
||||
tracing::error!("Lost connection to Postgres, recreating tree");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
let new_tree = create_tree(&conn).await;
|
||||
let mut lock = tree.write().await;
|
||||
*lock = new_tree;
|
||||
drop(lock);
|
||||
tracing::info!("Replaced tree");
|
||||
}
|
||||
});
|
||||
|
||||
let log = warp::log("fuzzysearch");
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(vec!["x-api-key"])
|
||||
.allow_methods(vec!["GET", "POST"]);
|
||||
|
||||
let options = warp::options().map(|| "✓");
|
||||
|
||||
let api = options.or(filters::search(db_pool, tree));
|
||||
let routes = api
|
||||
.or(warp::path::end()
|
||||
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
|
||||
.with(log)
|
||||
.with(cors)
|
||||
.recover(handlers::handle_rejection);
|
||||
|
||||
warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
|
||||
}
|
||||
|
||||
type Pool = bb8::Pool<bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||
|
@ -1,46 +1,48 @@
|
||||
use crate::types::*;
|
||||
use crate::utils::extract_rows;
|
||||
use crate::{Pool, Tree};
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{register_histogram, Histogram};
|
||||
use tracing_futures::Instrument;
|
||||
|
||||
use fuzzysearch_common::types::SearchResult;
|
||||
use crate::types::*;
|
||||
use crate::{Pool, Tree};
|
||||
use fuzzysearch_common::types::{SearchResult, SiteInfo};
|
||||
|
||||
pub type DB<'a> =
|
||||
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||
lazy_static! {
|
||||
static ref IMAGE_LOOKUP_DURATION: Histogram = register_histogram!(
|
||||
"fuzzysearch_api_image_lookup_seconds",
|
||||
"Duration to perform an image lookup"
|
||||
)
|
||||
.unwrap();
|
||||
static ref IMAGE_QUERY_DURATION: Histogram = register_histogram!(
|
||||
"fuzzysearch_api_image_query_seconds",
|
||||
"Duration to perform a single image lookup query"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(db))]
|
||||
pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
|
||||
let rows = db
|
||||
.query(
|
||||
"SELECT
|
||||
pub async fn lookup_api_key(key: &str, db: &sqlx::PgPool) -> Option<ApiKey> {
|
||||
sqlx::query_as!(
|
||||
ApiKey,
|
||||
"SELECT
|
||||
api_key.id,
|
||||
api_key.name_limit,
|
||||
api_key.image_limit,
|
||||
api_key.hash_limit,
|
||||
api_key.name,
|
||||
account.email
|
||||
account.email owner_email
|
||||
FROM
|
||||
api_key
|
||||
JOIN account
|
||||
ON account.id = api_key.user_id
|
||||
WHERE
|
||||
api_key.key = $1",
|
||||
&[&key],
|
||||
)
|
||||
.await
|
||||
.expect("Unable to query API keys");
|
||||
|
||||
match rows.into_iter().next() {
|
||||
Some(row) => Some(ApiKey {
|
||||
id: row.get(0),
|
||||
name_limit: row.get(1),
|
||||
image_limit: row.get(2),
|
||||
hash_limit: row.get(3),
|
||||
name: row.get(4),
|
||||
owner_email: row.get(5),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
api_key.key = $1
|
||||
",
|
||||
key
|
||||
)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(pool, tree))]
|
||||
@ -50,7 +52,7 @@ pub async fn image_query(
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> Result<Vec<SearchResult>, tokio_postgres::Error> {
|
||||
) -> Result<Vec<SearchResult>, sqlx::Error> {
|
||||
let mut results = image_query_sync(pool, tree, hashes, distance, hash);
|
||||
let mut matches = Vec::new();
|
||||
|
||||
@ -68,19 +70,30 @@ pub fn image_query_sync(
|
||||
hashes: Vec<i64>,
|
||||
distance: i64,
|
||||
hash: Option<Vec<u8>>,
|
||||
) -> tokio::sync::mpsc::Receiver<Result<Vec<SearchResult>, tokio_postgres::Error>> {
|
||||
let (mut tx, rx) = tokio::sync::mpsc::channel(50);
|
||||
) -> tokio::sync::mpsc::Receiver<Result<Vec<SearchResult>, sqlx::Error>> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(50);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let db = pool.get().await.unwrap();
|
||||
let db = pool;
|
||||
|
||||
for query_hash in hashes {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
|
||||
let _timer = IMAGE_LOOKUP_DURATION.start_timer();
|
||||
|
||||
let node = crate::Node::query(query_hash.to_be_bytes());
|
||||
let lock = tree.read().await;
|
||||
let items = lock.find(&node, distance as u64);
|
||||
|
||||
for (_dist, item) in items {
|
||||
let query = db.query("SELECT
|
||||
for (dist, item) in items {
|
||||
if seen.contains(&item.id) {
|
||||
continue;
|
||||
}
|
||||
seen.insert(item.id);
|
||||
|
||||
let _timer = IMAGE_QUERY_DURATION.start_timer();
|
||||
|
||||
let row = sqlx::query!("SELECT
|
||||
hashes.id,
|
||||
hashes.hash,
|
||||
hashes.furaffinity_id,
|
||||
@ -106,7 +119,16 @@ pub fn image_query_sync(
|
||||
END file_id,
|
||||
CASE
|
||||
WHEN e621_id IS NOT NULL THEN ARRAY(SELECT jsonb_array_elements_text(e.data->'sources'))
|
||||
END sources
|
||||
END sources,
|
||||
CASE
|
||||
WHEN furaffinity_id IS NOT NULL THEN (f.rating)
|
||||
WHEN e621_id IS NOT NULL THEN (e.data->>'rating')
|
||||
WHEN twitter_id IS NOT NULL THEN
|
||||
CASE
|
||||
WHEN (tw.data->'possibly_sensitive')::boolean IS true THEN 'adult'
|
||||
WHEN (tw.data->'possibly_sensitive')::boolean IS false THEN 'general'
|
||||
END
|
||||
END rating
|
||||
FROM
|
||||
hashes
|
||||
LEFT JOIN LATERAL (
|
||||
@ -133,14 +155,45 @@ pub fn image_query_sync(
|
||||
tweet_media.hash <@ (hashes.hash, 0)
|
||||
LIMIT 1
|
||||
) tm ON hashes.twitter_id IS NOT NULL
|
||||
WHERE hashes.id = $1", &[&item.id]).await;
|
||||
let rows = query.map(|rows| {
|
||||
extract_rows(rows, hash.as_deref()).into_iter().map(|mut file| {
|
||||
file.searched_hash = Some(query_hash);
|
||||
file
|
||||
}).collect()
|
||||
});
|
||||
tx.send(rows).await.unwrap();
|
||||
WHERE hashes.id = $1", item.id).map(|row| {
|
||||
let (site_id, site_info) = if let Some(fa_id) = row.furaffinity_id {
|
||||
(
|
||||
fa_id as i64,
|
||||
Some(SiteInfo::FurAffinity {
|
||||
file_id: row.file_id.unwrap(),
|
||||
})
|
||||
)
|
||||
} else if let Some(e621_id) = row.e621_id {
|
||||
(
|
||||
e621_id as i64,
|
||||
Some(SiteInfo::E621 {
|
||||
sources: row.sources,
|
||||
})
|
||||
)
|
||||
} else if let Some(twitter_id) = row.twitter_id {
|
||||
(twitter_id, Some(SiteInfo::Twitter))
|
||||
} else {
|
||||
(-1, None)
|
||||
};
|
||||
|
||||
let file = SearchResult {
|
||||
id: row.id,
|
||||
site_id,
|
||||
site_info,
|
||||
rating: row.rating.and_then(|rating| rating.parse().ok()),
|
||||
site_id_str: site_id.to_string(),
|
||||
url: row.url.unwrap_or_default(),
|
||||
hash: Some(row.hash),
|
||||
distance: Some(dist),
|
||||
artists: row.artists,
|
||||
filename: row.filename.unwrap_or_default(),
|
||||
searched_hash: Some(query_hash),
|
||||
};
|
||||
|
||||
vec![file]
|
||||
}).fetch_one(&db).await;
|
||||
|
||||
tx.send(row).await.unwrap();
|
||||
}
|
||||
}
|
||||
}.in_current_span());
|
||||
|
@ -10,7 +10,7 @@ use fuzzysearch_common::types::SearchResult;
|
||||
pub struct ApiKey {
|
||||
pub id: i32,
|
||||
pub name: Option<String>,
|
||||
pub owner_email: Option<String>,
|
||||
pub owner_email: String,
|
||||
pub name_limit: i16,
|
||||
pub image_limit: i16,
|
||||
pub hash_limit: i16,
|
||||
@ -22,7 +22,7 @@ pub enum RateLimit {
|
||||
/// This key is limited, we should deny the request.
|
||||
Limited,
|
||||
/// This key is available, contains the number of requests made.
|
||||
Available(i16),
|
||||
Available((i16, i16)),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -68,3 +68,8 @@ pub struct HashSearchOpts {
|
||||
pub struct HandleOpts {
|
||||
pub twitter: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UrlSearchOpts {
|
||||
pub url: String,
|
||||
}
|
||||
|
@ -1,7 +1,15 @@
|
||||
use crate::models::DB;
|
||||
use crate::types::*;
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{register_int_counter_vec, IntCounterVec};
|
||||
|
||||
use fuzzysearch_common::types::{SearchResult, SiteInfo};
|
||||
lazy_static! {
|
||||
pub static ref RATE_LIMIT_STATUS: IntCounterVec = register_int_counter_vec!(
|
||||
"fuzzysearch_api_rate_limit_count",
|
||||
"Number of allowed and rate limited requests",
|
||||
&["status"]
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! rate_limit {
|
||||
@ -9,18 +17,48 @@ macro_rules! rate_limit {
|
||||
rate_limit!($api_key, $db, $limit, $group, 1)
|
||||
};
|
||||
|
||||
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {
|
||||
let api_key = crate::models::lookup_api_key($api_key, $db)
|
||||
.await
|
||||
.ok_or_else(|| warp::reject::custom(Error::ApiKey))?;
|
||||
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {{
|
||||
let api_key = match crate::models::lookup_api_key($api_key, $db).await {
|
||||
Some(api_key) => api_key,
|
||||
None => return Ok(Box::new(Error::ApiKey)),
|
||||
};
|
||||
|
||||
let rate_limit =
|
||||
crate::utils::update_rate_limit($db, api_key.id, api_key.$limit, $group, $incr_by)
|
||||
.await
|
||||
.map_err(crate::handlers::map_postgres_err)?;
|
||||
let rate_limit = match crate::utils::update_rate_limit(
|
||||
$db,
|
||||
api_key.id,
|
||||
api_key.$limit,
|
||||
$group,
|
||||
$incr_by,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(rate_limit) => rate_limit,
|
||||
Err(err) => return Ok(Box::new(Error::Postgres(err))),
|
||||
};
|
||||
|
||||
if rate_limit == crate::types::RateLimit::Limited {
|
||||
return Err(warp::reject::custom(Error::RateLimit));
|
||||
match rate_limit {
|
||||
crate::types::RateLimit::Limited => {
|
||||
crate::utils::RATE_LIMIT_STATUS
|
||||
.with_label_values(&["limited"])
|
||||
.inc();
|
||||
return Ok(Box::new(Error::RateLimit));
|
||||
}
|
||||
crate::types::RateLimit::Available(count) => {
|
||||
crate::utils::RATE_LIMIT_STATUS
|
||||
.with_label_values(&["allowed"])
|
||||
.inc();
|
||||
count
|
||||
}
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! early_return {
|
||||
($val:expr) => {
|
||||
match $val {
|
||||
Ok(val) => val,
|
||||
Err(err) => return Ok(Box::new(Error::from(err))),
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -33,85 +71,38 @@ macro_rules! rate_limit {
|
||||
/// joined requests.
|
||||
#[tracing::instrument(skip(db))]
|
||||
pub async fn update_rate_limit(
|
||||
db: DB<'_>,
|
||||
db: &sqlx::PgPool,
|
||||
key_id: i32,
|
||||
key_group_limit: i16,
|
||||
group_name: &'static str,
|
||||
incr_by: i16,
|
||||
) -> Result<RateLimit, tokio_postgres::Error> {
|
||||
) -> Result<RateLimit, sqlx::Error> {
|
||||
let now = chrono::Utc::now();
|
||||
let timestamp = now.timestamp();
|
||||
let time_window = timestamp - (timestamp % 60);
|
||||
|
||||
let rows = db
|
||||
.query(
|
||||
"INSERT INTO
|
||||
rate_limit (api_key_id, time_window, group_name, count)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
ON CONFLICT ON CONSTRAINT unique_window
|
||||
DO UPDATE set count = rate_limit.count + $4
|
||||
RETURNING rate_limit.count",
|
||||
&[&key_id, &time_window, &group_name, &incr_by],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let count: i16 = rows[0].get(0);
|
||||
let count: i16 = sqlx::query_scalar!(
|
||||
"INSERT INTO
|
||||
rate_limit (api_key_id, time_window, group_name, count)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
ON CONFLICT ON CONSTRAINT unique_window
|
||||
DO UPDATE set count = rate_limit.count + $4
|
||||
RETURNING rate_limit.count",
|
||||
key_id,
|
||||
time_window,
|
||||
group_name,
|
||||
incr_by
|
||||
)
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
if count > key_group_limit {
|
||||
Ok(RateLimit::Limited)
|
||||
} else {
|
||||
Ok(RateLimit::Available(count))
|
||||
Ok(RateLimit::Available((
|
||||
key_group_limit - count,
|
||||
key_group_limit,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_rows<'a>(
|
||||
rows: Vec<tokio_postgres::Row>,
|
||||
hash: Option<&'a [u8]>,
|
||||
) -> impl IntoIterator<Item = SearchResult> + 'a {
|
||||
rows.into_iter().map(move |row| {
|
||||
let dbhash: i64 = row.get("hash");
|
||||
let dbbytes = dbhash.to_be_bytes();
|
||||
|
||||
let (furaffinity_id, e621_id, twitter_id): (Option<i32>, Option<i32>, Option<i64>) = (
|
||||
row.get("furaffinity_id"),
|
||||
row.get("e621_id"),
|
||||
row.get("twitter_id"),
|
||||
);
|
||||
|
||||
let (site_id, site_info) = if let Some(fa_id) = furaffinity_id {
|
||||
(
|
||||
fa_id as i64,
|
||||
Some(SiteInfo::FurAffinity {
|
||||
file_id: row.get("file_id"),
|
||||
}),
|
||||
)
|
||||
} else if let Some(e6_id) = e621_id {
|
||||
(
|
||||
e6_id as i64,
|
||||
Some(SiteInfo::E621 {
|
||||
sources: row.get("sources"),
|
||||
}),
|
||||
)
|
||||
} else if let Some(t_id) = twitter_id {
|
||||
(t_id, Some(SiteInfo::Twitter))
|
||||
} else {
|
||||
(-1, None)
|
||||
};
|
||||
|
||||
SearchResult {
|
||||
id: row.get("id"),
|
||||
site_id,
|
||||
site_info,
|
||||
site_id_str: site_id.to_string(),
|
||||
url: row.get("url"),
|
||||
hash: Some(dbhash),
|
||||
distance: hash
|
||||
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
|
||||
.flatten(),
|
||||
artists: row.get("artists"),
|
||||
filename: row.get("filename"),
|
||||
searched_hash: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user