mirror of
https://github.com/Syfaro/fuzzysearch.git
synced 2024-11-23 15:22:31 +00:00
Initial commit.
This commit is contained in:
commit
82dd1ef8b7
1
.dockerignore
Normal file
1
.dockerignore
Normal file
@ -0,0 +1 @@
|
|||||||
|
target/
|
26
.drone.yml
Normal file
26
.drone.yml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
---
|
||||||
|
kind: pipeline
|
||||||
|
type: docker
|
||||||
|
name: default
|
||||||
|
|
||||||
|
platform:
|
||||||
|
os: linux
|
||||||
|
arch: amd64
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: docker
|
||||||
|
image: plugins/docker
|
||||||
|
settings:
|
||||||
|
auto_tag: true
|
||||||
|
password:
|
||||||
|
from_secret: docker_password
|
||||||
|
registry: registry.huefox.com
|
||||||
|
repo: registry.huefox.com/fuzzysearch
|
||||||
|
username:
|
||||||
|
from_secret: docker_username
|
||||||
|
|
||||||
|
trigger:
|
||||||
|
branch:
|
||||||
|
- master
|
||||||
|
|
||||||
|
...
|
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
2055
Cargo.lock
generated
Normal file
2055
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
Normal file
26
Cargo.toml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
[package]
|
||||||
|
name = "fuzzysearch"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = ["Syfaro <syfaro@huefox.com>"]
|
||||||
|
edition = "2018"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
log = "0.4.8"
|
||||||
|
pretty_env_logger = "0.3.1"
|
||||||
|
|
||||||
|
tokio = { version = "0.2.9", features = ["full"] }
|
||||||
|
chrono = "0.4.10"
|
||||||
|
futures = "0.3.1"
|
||||||
|
futures-util = "0.3.1"
|
||||||
|
bytes = "0.5.3"
|
||||||
|
|
||||||
|
serde = { version = "1.0.104", features = ["derive"] }
|
||||||
|
warp = { git = "https://github.com/seanmonstar/warp.git" }
|
||||||
|
|
||||||
|
tokio-postgres = "0.5.1"
|
||||||
|
bb8 = { git = "https://github.com/khuey/bb8.git" }
|
||||||
|
bb8-postgres = { git = "https://github.com/khuey/bb8.git" }
|
||||||
|
|
||||||
|
img_hash = "3.0.0"
|
||||||
|
image = "0.22"
|
||||||
|
hamming = "0.1.3"
|
10
Dockerfile
Normal file
10
Dockerfile
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
FROM rustlang/rust:nightly-slim AS builder
|
||||||
|
WORKDIR /src
|
||||||
|
COPY . .
|
||||||
|
RUN cargo install --root / --path .
|
||||||
|
|
||||||
|
FROM debian:buster-slim
|
||||||
|
EXPOSE 8080
|
||||||
|
WORKDIR /app
|
||||||
|
COPY --from=builder /bin/fuzzysearch /bin/fuzzysearch
|
||||||
|
CMD ["/bin/fuzzysearch"]
|
34
src/filters.rs
Normal file
34
src/filters.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
use crate::types::*;
|
||||||
|
use crate::{handlers, Pool};
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use warp::{Filter, Rejection, Reply};
|
||||||
|
|
||||||
|
pub fn search(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||||
|
search_file(db.clone()).or(search_image(db))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||||
|
warp::path("file")
|
||||||
|
.and(warp::get())
|
||||||
|
.and(warp::query::<FileSearchOpts>())
|
||||||
|
.and(with_pool(db))
|
||||||
|
.and(with_api_key())
|
||||||
|
.and_then(handlers::search_file)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn search_image(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||||
|
warp::path("image")
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::multipart::form().max_length(1024 * 1024 * 10))
|
||||||
|
.and(with_pool(db))
|
||||||
|
.and(with_api_key())
|
||||||
|
.and_then(handlers::search_image)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_api_key() -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||||
|
warp::header::<String>("x-api-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_pool(db: Pool) -> impl Filter<Extract = (Pool,), Error = Infallible> + Clone {
|
||||||
|
warp::any().map(move || db.clone())
|
||||||
|
}
|
251
src/handlers.rs
Normal file
251
src/handlers.rs
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
use crate::types::*;
|
||||||
|
use crate::utils::{extract_e621_rows, extract_fa_rows};
|
||||||
|
use crate::{rate_limit, Pool};
|
||||||
|
use log::debug;
|
||||||
|
use warp::{reject, Rejection, Reply};
|
||||||
|
|
||||||
|
fn map_bb8_err(err: bb8::RunError<tokio_postgres::Error>) -> Rejection {
|
||||||
|
reject::custom(Error::from(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_postgres_err(err: tokio_postgres::Error) -> Rejection {
|
||||||
|
reject::custom(Error::from(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum Error {
|
||||||
|
BB8(bb8::RunError<tokio_postgres::Error>),
|
||||||
|
Postgres(tokio_postgres::Error),
|
||||||
|
InvalidData,
|
||||||
|
ApiKey,
|
||||||
|
RateLimit,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<bb8::RunError<tokio_postgres::Error>> for Error {
|
||||||
|
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self {
|
||||||
|
Error::BB8(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<tokio_postgres::Error> for Error {
|
||||||
|
fn from(err: tokio_postgres::Error) -> Self {
|
||||||
|
Error::Postgres(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl warp::reject::Reject for Error {}
|
||||||
|
|
||||||
|
pub async fn search_image(
|
||||||
|
form: warp::multipart::FormData,
|
||||||
|
db: Pool,
|
||||||
|
api_key: String,
|
||||||
|
) -> Result<impl Reply, Rejection> {
|
||||||
|
let db = db.get().await.map_err(map_bb8_err)?;
|
||||||
|
|
||||||
|
rate_limit!(&api_key, &db, image_limit, "image");
|
||||||
|
|
||||||
|
use bytes::BufMut;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
let parts: Vec<_> = form.collect().await;
|
||||||
|
let mut parts = parts
|
||||||
|
.into_iter()
|
||||||
|
.map(|part| {
|
||||||
|
let part = part.unwrap();
|
||||||
|
(part.name().to_string(), part)
|
||||||
|
})
|
||||||
|
.collect::<std::collections::HashMap<_, _>>();
|
||||||
|
let image = parts.remove("image").unwrap();
|
||||||
|
|
||||||
|
let bytes = image
|
||||||
|
.stream()
|
||||||
|
.fold(bytes::BytesMut::new(), |mut b, data| {
|
||||||
|
b.put(data.unwrap());
|
||||||
|
async move { b }
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let hash = {
|
||||||
|
let hasher = crate::get_hasher();
|
||||||
|
let image = image::load_from_memory(&bytes).unwrap();
|
||||||
|
hasher.hash_image(&image)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut buf: [u8; 8] = [0; 8];
|
||||||
|
buf.copy_from_slice(&hash.as_bytes());
|
||||||
|
|
||||||
|
let num = i64::from_be_bytes(buf);
|
||||||
|
|
||||||
|
debug!("Matching hash {}", num);
|
||||||
|
|
||||||
|
let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![&num];
|
||||||
|
|
||||||
|
let fa = db.query(
|
||||||
|
"SELECT
|
||||||
|
submission.id,
|
||||||
|
submission.url,
|
||||||
|
submission.filename,
|
||||||
|
submission.file_id,
|
||||||
|
submission.hash,
|
||||||
|
submission.hash_int,
|
||||||
|
artist.name
|
||||||
|
FROM
|
||||||
|
submission
|
||||||
|
JOIN artist
|
||||||
|
ON artist.id = submission.artist_id
|
||||||
|
WHERE
|
||||||
|
hash_int <@ ($1, 10)",
|
||||||
|
¶ms,
|
||||||
|
);
|
||||||
|
|
||||||
|
let e621 = db.query(
|
||||||
|
"SELECT
|
||||||
|
e621.id,
|
||||||
|
e621.hash,
|
||||||
|
e621.data->>'file_url' url,
|
||||||
|
e621.data->>'md5' md5,
|
||||||
|
sources.list sources,
|
||||||
|
artists.list artists,
|
||||||
|
(e621.data->>'md5') || '.' || (e621.data->>'file_ext') filename
|
||||||
|
FROM
|
||||||
|
e621,
|
||||||
|
LATERAL (
|
||||||
|
SELECT array_agg(s) list
|
||||||
|
FROM jsonb_array_elements_text(data->'sources') s
|
||||||
|
) sources,
|
||||||
|
LATERAL (
|
||||||
|
SELECT array_agg(s) list
|
||||||
|
FROM jsonb_array_elements_text(data->'artist') s
|
||||||
|
) artists
|
||||||
|
WHERE
|
||||||
|
hash <@ ($1, 10)",
|
||||||
|
¶ms,
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = futures::future::join(fa, e621).await;
|
||||||
|
let (fa_results, e621_results) = (results.0.unwrap(), results.1.unwrap());
|
||||||
|
|
||||||
|
let mut items = Vec::with_capacity(fa_results.len() + e621_results.len());
|
||||||
|
|
||||||
|
items.extend(extract_fa_rows(fa_results, Some(&hash.as_bytes())));
|
||||||
|
items.extend(extract_e621_rows(e621_results, Some(&hash.as_bytes())));
|
||||||
|
|
||||||
|
items.sort_by(|a, b| {
|
||||||
|
a.distance
|
||||||
|
.unwrap_or(u64::max_value())
|
||||||
|
.partial_cmp(&b.distance.unwrap_or(u64::max_value()))
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
let similarity = ImageSimilarity {
|
||||||
|
hash: num,
|
||||||
|
matches: items,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(warp::reply::json(&similarity))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn search_file(
|
||||||
|
opts: FileSearchOpts,
|
||||||
|
db: Pool,
|
||||||
|
api_key: String,
|
||||||
|
) -> Result<impl Reply, Rejection> {
|
||||||
|
let db = db.get().await.map_err(map_bb8_err)?;
|
||||||
|
|
||||||
|
rate_limit!(&api_key, &db, name_limit, "file");
|
||||||
|
|
||||||
|
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) =
|
||||||
|
if let Some(ref id) = opts.id {
|
||||||
|
("file_id = $1", id)
|
||||||
|
} else if let Some(ref name) = opts.name {
|
||||||
|
("lower(filename) = lower($1)", name)
|
||||||
|
} else if let Some(ref url) = opts.url {
|
||||||
|
("lower(url) = lower($1)", url)
|
||||||
|
} else {
|
||||||
|
return Err(warp::reject::custom(Error::InvalidData));
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Searching for {:?}", opts);
|
||||||
|
|
||||||
|
let query = format!(
|
||||||
|
"SELECT
|
||||||
|
submission.id,
|
||||||
|
submission.url,
|
||||||
|
submission.filename,
|
||||||
|
submission.file_id,
|
||||||
|
artist.name
|
||||||
|
FROM
|
||||||
|
submission
|
||||||
|
JOIN artist
|
||||||
|
ON artist.id = submission.artist_id
|
||||||
|
WHERE
|
||||||
|
{}
|
||||||
|
LIMIT 10",
|
||||||
|
filter
|
||||||
|
);
|
||||||
|
|
||||||
|
let matches: Vec<_> = db
|
||||||
|
.query::<str>(&*query, &[val])
|
||||||
|
.await
|
||||||
|
.map_err(map_postgres_err)?
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| File {
|
||||||
|
id: row.get("id"),
|
||||||
|
url: row.get("url"),
|
||||||
|
filename: row.get("filename"),
|
||||||
|
artists: row
|
||||||
|
.get::<&str, Option<String>>("name")
|
||||||
|
.map(|artist| vec![artist]),
|
||||||
|
distance: None,
|
||||||
|
hash: None,
|
||||||
|
site_info: Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||||
|
file_id: row.get("file_id"),
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(warp::reply::json(&matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
|
||||||
|
let (code, message) = if err.is_not_found() {
|
||||||
|
(
|
||||||
|
warp::http::StatusCode::NOT_FOUND,
|
||||||
|
"This page does not exist",
|
||||||
|
)
|
||||||
|
} else if let Some(err) = err.find::<Error>() {
|
||||||
|
match err {
|
||||||
|
Error::BB8(_inner) => (
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"A database error occured",
|
||||||
|
),
|
||||||
|
Error::Postgres(_inner) => (
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"A database error occured",
|
||||||
|
),
|
||||||
|
Error::InvalidData => (
|
||||||
|
warp::http::StatusCode::BAD_REQUEST,
|
||||||
|
"Unable to operate on provided data",
|
||||||
|
),
|
||||||
|
Error::ApiKey => (
|
||||||
|
warp::http::StatusCode::UNAUTHORIZED,
|
||||||
|
"Invalid API key provided",
|
||||||
|
),
|
||||||
|
Error::RateLimit => (
|
||||||
|
warp::http::StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
"Your API token is rate limited",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"An unknown error occured",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = warp::reply::json(&ErrorMessage {
|
||||||
|
code: code.as_u16(),
|
||||||
|
message: message.into(),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(warp::reply::with_status(json, code))
|
||||||
|
}
|
53
src/main.rs
Normal file
53
src/main.rs
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
mod filters;
|
||||||
|
mod handlers;
|
||||||
|
mod models;
|
||||||
|
mod types;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
use warp::Filter;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
pretty_env_logger::init();
|
||||||
|
|
||||||
|
let s = std::env::var("POSTGRES_DSN").expect("Missing POSTGRES_DSN");
|
||||||
|
|
||||||
|
let manager = bb8_postgres::PostgresConnectionManager::new(
|
||||||
|
tokio_postgres::Config::from_str(&s).expect("Invalid POSTGRES_DSN"),
|
||||||
|
tokio_postgres::NoTls,
|
||||||
|
);
|
||||||
|
|
||||||
|
let db_pool = bb8::Pool::builder()
|
||||||
|
.build(manager)
|
||||||
|
.await
|
||||||
|
.expect("Unable to build Postgres pool");
|
||||||
|
|
||||||
|
let log = warp::log("fuzzysearch");
|
||||||
|
let cors = warp::cors()
|
||||||
|
.allow_any_origin()
|
||||||
|
.allow_methods(vec!["GET", "POST"]);
|
||||||
|
|
||||||
|
let api = filters::search(db_pool);
|
||||||
|
let routes = api
|
||||||
|
.or(warp::path::end()
|
||||||
|
.map(|| warp::redirect(warp::http::Uri::from_static("https://fuzzysearch.net"))))
|
||||||
|
.with(log)
|
||||||
|
.with(cors)
|
||||||
|
.recover(handlers::handle_rejection);
|
||||||
|
|
||||||
|
warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
type Pool = bb8::Pool<bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||||
|
|
||||||
|
fn get_hasher() -> img_hash::Hasher {
|
||||||
|
use img_hash::{HashAlg::Gradient, HasherConfig};
|
||||||
|
|
||||||
|
HasherConfig::new()
|
||||||
|
.hash_alg(Gradient)
|
||||||
|
.hash_size(8, 8)
|
||||||
|
.preproc_dct()
|
||||||
|
.to_hasher()
|
||||||
|
}
|
36
src/models.rs
Normal file
36
src/models.rs
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
use crate::types::*;
|
||||||
|
|
||||||
|
pub type DB<'a> =
|
||||||
|
&'a bb8::PooledConnection<'a, bb8_postgres::PostgresConnectionManager<tokio_postgres::NoTls>>;
|
||||||
|
|
||||||
|
pub async fn lookup_api_key(key: &str, db: DB<'_>) -> Option<ApiKey> {
|
||||||
|
let rows = db
|
||||||
|
.query(
|
||||||
|
"SELECT
|
||||||
|
api_key.id,
|
||||||
|
api_key.name_limit,
|
||||||
|
api_key.image_limit,
|
||||||
|
api_key.name,
|
||||||
|
account.email
|
||||||
|
FROM
|
||||||
|
api_key
|
||||||
|
JOIN account
|
||||||
|
ON account.id = api_key.user_id
|
||||||
|
WHERE
|
||||||
|
api_key.key = $1",
|
||||||
|
&[&key],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("Unable to query API keys");
|
||||||
|
|
||||||
|
match rows.into_iter().next() {
|
||||||
|
Some(row) => Some(ApiKey {
|
||||||
|
id: row.get(0),
|
||||||
|
name_limit: row.get(1),
|
||||||
|
image_limit: row.get(2),
|
||||||
|
name: row.get(3),
|
||||||
|
owner_email: row.get(4),
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
81
src/types.rs
Normal file
81
src/types.rs
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// An API key representation from the database.alloc
|
||||||
|
///
|
||||||
|
/// May contain information about the owner, always has rate limit information.
|
||||||
|
/// Limits are the number of requests allowed per minute.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ApiKey {
|
||||||
|
pub id: i32,
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub owner_email: Option<String>,
|
||||||
|
pub name_limit: i16,
|
||||||
|
pub image_limit: i16,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The status of an API key's rate limit.
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum RateLimit {
|
||||||
|
/// This key is limited, we should deny the request.
|
||||||
|
Limited,
|
||||||
|
/// This key is available, contains the number of requests made.
|
||||||
|
Available(i16),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A general type for every file.
|
||||||
|
#[derive(Debug, Default, Serialize)]
|
||||||
|
pub struct File {
|
||||||
|
pub id: i32,
|
||||||
|
pub url: String,
|
||||||
|
pub filename: String,
|
||||||
|
pub artists: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub site_info: Option<SiteInfo>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hash: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub distance: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "site", content = "site_info")]
|
||||||
|
pub enum SiteInfo {
|
||||||
|
FurAffinity(FurAffinityFile),
|
||||||
|
#[serde(rename = "e621")]
|
||||||
|
E621(E621File),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Information about a file hosted on FurAffinity.
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct FurAffinityFile {
|
||||||
|
pub file_id: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Information about a file hosted on e621.
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct E621File {
|
||||||
|
pub file_md5: String,
|
||||||
|
pub sources: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct FileSearchOpts {
|
||||||
|
pub id: Option<i32>,
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ImageSimilarity {
|
||||||
|
pub hash: i64,
|
||||||
|
pub matches: Vec<File>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct ErrorMessage {
|
||||||
|
pub code: u16,
|
||||||
|
pub message: String,
|
||||||
|
}
|
117
src/utils.rs
Normal file
117
src/utils.rs
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
use crate::models::DB;
|
||||||
|
use crate::types::*;
|
||||||
|
use log::debug;
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! rate_limit {
|
||||||
|
($api_key:expr, $db:expr, $limit:tt, $group:expr) => {
|
||||||
|
rate_limit!($api_key, $db, $limit, $group, 1)
|
||||||
|
};
|
||||||
|
|
||||||
|
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {
|
||||||
|
let api_key = crate::models::lookup_api_key($api_key, $db)
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| warp::reject::custom(Error::ApiKey))?;
|
||||||
|
|
||||||
|
let rate_limit =
|
||||||
|
crate::utils::update_rate_limit($db, api_key.id, api_key.$limit, $group, $incr_by)
|
||||||
|
.await
|
||||||
|
.map_err(crate::handlers::map_postgres_err)?;
|
||||||
|
|
||||||
|
if rate_limit == crate::types::RateLimit::Limited {
|
||||||
|
return Err(warp::reject::custom(Error::RateLimit));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Increment the rate limit for a group.
|
||||||
|
///
|
||||||
|
/// We need to specify the ID of the API key to increment, the key's limit for
|
||||||
|
/// the specified group, the name of the group we're incrementing, and the
|
||||||
|
/// amount to increment for this request. This should remain as 1 except for
|
||||||
|
/// joined requests.
|
||||||
|
pub async fn update_rate_limit(
|
||||||
|
db: DB<'_>,
|
||||||
|
key_id: i32,
|
||||||
|
key_group_limit: i16,
|
||||||
|
group_name: &'static str,
|
||||||
|
incr_by: i16,
|
||||||
|
) -> Result<RateLimit, tokio_postgres::Error> {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let timestamp = now.timestamp();
|
||||||
|
let time_window = timestamp - (timestamp % 60);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Incrementing rate limit for: {}-{} by {}",
|
||||||
|
key_id, group_name, incr_by
|
||||||
|
);
|
||||||
|
|
||||||
|
let rows = db
|
||||||
|
.query(
|
||||||
|
"INSERT INTO
|
||||||
|
rate_limit (api_key_id, time_window, group_name, count)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, $4)
|
||||||
|
ON CONFLICT ON CONSTRAINT unique_window
|
||||||
|
DO UPDATE set count = rate_limit.count + $4
|
||||||
|
RETURNING rate_limit.count",
|
||||||
|
&[&key_id, &time_window, &group_name, &incr_by],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let count: i16 = rows[0].get(0);
|
||||||
|
|
||||||
|
if count > key_group_limit {
|
||||||
|
Ok(RateLimit::Limited)
|
||||||
|
} else {
|
||||||
|
Ok(RateLimit::Available(count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_fa_rows<'a>(
|
||||||
|
rows: Vec<tokio_postgres::Row>,
|
||||||
|
hash: Option<&'a [u8]>,
|
||||||
|
) -> impl IntoIterator<Item = File> + 'a {
|
||||||
|
rows.into_iter().map(move |row| {
|
||||||
|
let dbbytes: Vec<u8> = row.get("hash");
|
||||||
|
|
||||||
|
File {
|
||||||
|
id: row.get("id"),
|
||||||
|
url: row.get("url"),
|
||||||
|
filename: row.get("filename"),
|
||||||
|
hash: row.get("hash_int"),
|
||||||
|
distance: hash
|
||||||
|
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
|
||||||
|
.flatten(),
|
||||||
|
site_info: Some(SiteInfo::FurAffinity(FurAffinityFile {
|
||||||
|
file_id: row.get("file_id"),
|
||||||
|
})),
|
||||||
|
artists: row.get::<&str, Option<String>>("name").map(|row| vec![row]),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_e621_rows<'a>(
|
||||||
|
rows: Vec<tokio_postgres::Row>,
|
||||||
|
hash: Option<&'a [u8]>,
|
||||||
|
) -> impl IntoIterator<Item = File> + 'a {
|
||||||
|
rows.into_iter().map(move |row| {
|
||||||
|
let dbhash: i64 = row.get("hash");
|
||||||
|
let dbbytes = dbhash.to_be_bytes();
|
||||||
|
|
||||||
|
File {
|
||||||
|
id: row.get("id"),
|
||||||
|
url: row.get("url"),
|
||||||
|
hash: Some(dbhash),
|
||||||
|
distance: hash
|
||||||
|
.map(|hash| hamming::distance_fast(&dbbytes, &hash).ok())
|
||||||
|
.flatten(),
|
||||||
|
site_info: Some(SiteInfo::E621(E621File {
|
||||||
|
file_md5: row.get("md5"),
|
||||||
|
sources: row.get("sources"),
|
||||||
|
})),
|
||||||
|
artists: row.get("artists"),
|
||||||
|
filename: row.get("filename"),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user