Add ability to search by URL, rate limit headers (#2)

* Initial progress on searching by URL.

* Avoid rejections for error messages.

* Handle some more rejections.

* Fix build issues.

* Remove detailed error messages.

* Fix issue with built Docker image.

* Add rate limit headers to all responses.

* Remove unneeded dependency.

* Limit URLs to 10MB.
This commit is contained in:
Syfaro 2021-01-21 21:21:16 -05:00 committed by GitHub
parent f6319e6d90
commit 3ade5aeba9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 878 additions and 627 deletions

View File

@ -37,4 +37,8 @@ steps:
exclude:
- master
---
kind: signature
hmac: c17d371c096be4b0544aad509cd58fa475e8737ac2c2dd77db5603044e3a5892
...

1073
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -13,19 +13,24 @@ opentelemetry = "0.6"
opentelemetry-jaeger = "0.5"
tracing-opentelemetry = "0.5"
tokio = { version = "0.3", features = ["macros", "rt-multi-thread", "sync"] }
tokio = { version = "1", features = ["full"] }
async-stream = "0.3"
futures = "0.3"
futures-util = "0.3"
chrono = "0.4"
bytes = "0.5"
bytes = "1"
serde = { version = "1", features = ["derive"] }
warp = "0.2"
serde_json = "1"
tokio-postgres = "0.6"
bb8 = "0.6"
bb8-postgres = "0.6"
warp = "0.3"
reqwest = "0.11"
hyper = "0.14"
tokio-postgres = "0.7"
bb8 = "0.7"
bb8-postgres = "0.7"
img_hash = "3"
image = "0.23"

View File

@ -1,10 +1,12 @@
FROM rust:1-slim AS builder
WORKDIR /src
RUN apt-get update -y && apt-get install -y libssl-dev pkg-config
COPY . .
RUN cargo install --root / --path .
FROM debian:buster-slim
EXPOSE 8080
WORKDIR /app
RUN apt-get update -y && apt-get install -y openssl ca-certificates && rm -rf /var/lib/apt/lists/*
COPY --from=builder /bin/fuzzysearch /bin/fuzzysearch
CMD ["/bin/fuzzysearch"]

View File

@ -10,9 +10,10 @@ pub fn search(
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
search_image(db.clone(), tree.clone())
.or(search_hashes(db.clone(), tree.clone()))
.or(stream_search_image(db.clone(), tree))
.or(stream_search_image(db.clone(), tree.clone()))
.or(search_file(db.clone()))
.or(check_handle(db))
.or(check_handle(db.clone()))
.or(search_image_by_url(db, tree))
}
pub fn search_file(db: Pool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
@ -54,6 +55,19 @@ pub fn search_image(
})
}
pub fn search_image_by_url(
db: Pool,
tree: Tree,
) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("url")
.and(warp::get())
.and(warp::query::<URLSearchOpts>())
.and(with_pool(db))
.and(with_tree(tree))
.and(with_api_key())
.and_then(handlers::search_image_by_url)
}
pub fn search_hashes(
db: Pool,
tree: Tree,

View File

@ -1,27 +1,56 @@
use crate::models::{image_query, image_query_sync};
use crate::types::*;
use crate::{rate_limit, Pool, Tree};
use crate::{early_return, rate_limit, Pool, Tree};
use std::convert::TryInto;
use tracing::{span, warn};
use tracing_futures::Instrument;
use warp::{reject, Rejection, Reply};
fn map_bb8_err(err: bb8::RunError<tokio_postgres::Error>) -> Rejection {
reject::custom(Error::from(err))
}
fn map_postgres_err(err: tokio_postgres::Error) -> Rejection {
reject::custom(Error::from(err))
}
use warp::{Rejection, Reply};
#[derive(Debug)]
enum Error {
BB8(bb8::RunError<tokio_postgres::Error>),
Postgres(tokio_postgres::Error),
Reqwest(reqwest::Error),
InvalidData,
InvalidImage,
ApiKey,
RateLimit,
}
impl warp::Reply for Error {
fn into_response(self) -> warp::reply::Response {
let msg = match self {
Error::BB8(_) | Error::Postgres(_) | Error::Reqwest(_) => ErrorMessage {
code: 500,
message: "Internal server error".to_string(),
},
Error::InvalidData => ErrorMessage {
code: 400,
message: "Invalid data provided".to_string(),
},
Error::InvalidImage => ErrorMessage {
code: 400,
message: "Invalid image provided".to_string(),
},
Error::ApiKey => ErrorMessage {
code: 401,
message: "Invalid API key".to_string(),
},
Error::RateLimit => ErrorMessage {
code: 429,
message: "Too many requests".to_string(),
},
};
let body = hyper::body::Body::from(serde_json::to_string(&msg).unwrap());
warp::http::Response::builder()
.status(msg.code)
.body(body)
.unwrap()
}
}
impl From<bb8::RunError<tokio_postgres::Error>> for Error {
fn from(err: bb8::RunError<tokio_postgres::Error>) -> Self {
Error::BB8(err)
@ -34,12 +63,16 @@ impl From<tokio_postgres::Error> for Error {
}
}
impl warp::reject::Reject for Error {}
impl From<reqwest::Error> for Error {
fn from(err: reqwest::Error) -> Self {
Error::Reqwest(err)
}
}
#[tracing::instrument(skip(form))]
async fn hash_input(form: warp::multipart::FormData) -> (i64, img_hash::ImageHash<[u8; 8]>) {
use bytes::BufMut;
use futures_util::StreamExt;
use futures::StreamExt;
let parts: Vec<_> = form.collect().await;
let mut parts = parts
@ -82,11 +115,11 @@ pub async fn search_image(
pool: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = pool.get().await.map_err(map_bb8_err)?;
) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(pool.get().await);
rate_limit!(&api_key, &db, image_limit, "image");
rate_limit!(&api_key, &db, hash_limit, "hash");
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
let (num, hash) = hash_input(form).await;
@ -139,7 +172,20 @@ pub async fn search_image(
matches: items,
};
Ok(warp::reply::json(&similarity))
let resp = warp::http::Response::builder()
.header("x-image-hash", num.to_string())
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&similarity).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn stream_image(
@ -147,34 +193,38 @@ pub async fn stream_image(
pool: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
use futures_util::StreamExt;
let db = pool.get().await.map_err(map_bb8_err)?;
) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(pool.get().await);
rate_limit!(&api_key, &db, image_limit, "image", 2);
rate_limit!(&api_key, &db, hash_limit, "hash");
let (num, hash) = hash_input(form).await;
let event_stream = image_query_sync(
let mut query = image_query_sync(
pool.clone(),
tree,
vec![num],
10,
Some(hash.as_bytes().to_vec()),
)
.map(sse_matches);
);
Ok(warp::sse::reply(event_stream))
let event_stream = async_stream::stream! {
while let Some(result) = query.recv().await {
yield sse_matches(result);
}
};
Ok(Box::new(warp::sse::reply(event_stream)))
}
#[allow(clippy::unnecessary_wraps)]
fn sse_matches(
matches: Result<Vec<File>, tokio_postgres::Error>,
) -> Result<impl warp::sse::ServerSentEvent, core::convert::Infallible> {
) -> Result<warp::sse::Event, core::convert::Infallible> {
let items = matches.unwrap();
Ok(warp::sse::json(items))
Ok(warp::sse::Event::default().json_data(items).unwrap())
}
pub async fn search_hashes(
@ -182,9 +232,9 @@ pub async fn search_hashes(
db: Pool,
tree: Tree,
api_key: String,
) -> Result<impl Reply, Rejection> {
) -> Result<Box<dyn Reply>, Rejection> {
let pool = db.clone();
let db = db.get().await.map_err(map_bb8_err)?;
let db = early_return!(db.get().await);
let hashes: Vec<i64> = opts
.hashes
@ -194,10 +244,10 @@ pub async fn search_hashes(
.collect();
if hashes.is_empty() {
return Err(warp::reject::custom(Error::InvalidData));
return Ok(Box::new(Error::InvalidData));
}
rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image", hashes.len() as i16);
let mut results = image_query_sync(
pool,
@ -209,20 +259,30 @@ pub async fn search_hashes(
let mut matches = Vec::new();
while let Some(r) = results.recv().await {
matches.extend(r.map_err(|e| warp::reject::custom(Error::Postgres(e)))?);
matches.extend(early_return!(r));
}
Ok(warp::reply::json(&matches))
let resp = warp::http::Response::builder()
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("content-type", "application/json")
.body(serde_json::to_string(&matches).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn search_file(
opts: FileSearchOpts,
db: Pool,
api_key: String,
) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(db.get().await);
rate_limit!(&api_key, &db, name_limit, "file");
let file_remaining = rate_limit!(&api_key, &db, name_limit, "file");
let (filter, val): (&'static str, &(dyn tokio_postgres::types::ToSql + Sync)) =
if let Some(ref id) = opts.id {
@ -232,7 +292,7 @@ pub async fn search_file(
} else if let Some(ref url) = opts.url {
("lower(url) = lower($1)", url)
} else {
return Err(warp::reject::custom(Error::InvalidData));
return Ok(Box::new(Error::InvalidData));
};
let query = format!(
@ -255,11 +315,11 @@ pub async fn search_file(
filter
);
let matches: Vec<_> = db
.query::<str>(&*query, &[val])
let matches: Vec<_> = early_return!(
db.query::<str>(&*query, &[val])
.instrument(span!(tracing::Level::TRACE, "waiting for db"))
.await
.map_err(map_postgres_err)?
)
.into_iter()
.map(|row| File {
id: row.get("hash_id"),
@ -279,29 +339,119 @@ pub async fn search_file(
})
.collect();
Ok(warp::reply::json(&matches))
let resp = warp::http::Response::builder()
.header("x-rate-limit-total-file", file_remaining.1.to_string())
.header("x-rate-limit-remaining-file", file_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&matches).unwrap())
.unwrap();
Ok(Box::new(resp))
}
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<impl Reply, Rejection> {
let db = db.get().await.map_err(map_bb8_err)?;
pub async fn check_handle(opts: HandleOpts, db: Pool) -> Result<Box<dyn Reply>, Rejection> {
let db = early_return!(db.get().await);
let exists = if let Some(handle) = opts.twitter {
!db.query(
!early_return!(
db.query(
"SELECT 1 FROM twitter_user WHERE lower(data->>'screen_name') = lower($1)",
&[&handle],
)
.await
.map_err(map_postgres_err)?
)
.is_empty()
} else {
false
};
Ok(warp::reply::json(&exists))
Ok(Box::new(warp::reply::json(&exists)))
}
pub async fn search_image_by_url(
opts: URLSearchOpts,
pool: Pool,
tree: Tree,
api_key: String,
) -> Result<Box<dyn Reply>, Rejection> {
use bytes::BufMut;
let url = opts.url;
let db = early_return!(pool.get().await);
let image_remaining = rate_limit!(&api_key, &db, image_limit, "image");
let hash_remaining = rate_limit!(&api_key, &db, hash_limit, "hash");
let mut resp = match reqwest::get(&url).await {
Ok(resp) => resp,
Err(_err) => return Ok(Box::new(Error::InvalidImage)),
};
let content_length = resp
.headers()
.get("content-length")
.and_then(|len| {
String::from_utf8_lossy(len.as_bytes())
.parse::<usize>()
.ok()
})
.unwrap_or(0);
if content_length > 10_000_000 {
return Ok(Box::new(Error::InvalidImage));
}
let mut buf = bytes::BytesMut::with_capacity(content_length);
while let Some(chunk) = early_return!(resp.chunk().await) {
if buf.len() + chunk.len() > 10_000_000 {
return Ok(Box::new(Error::InvalidImage));
}
buf.put(chunk);
}
let hash = tokio::task::spawn_blocking(move || {
let hasher = crate::get_hasher();
let image = image::load_from_memory(&buf).unwrap();
hasher.hash_image(&image)
})
.instrument(span!(tracing::Level::TRACE, "hashing image"))
.await
.unwrap();
let hash: [u8; 8] = hash.as_bytes().try_into().unwrap();
let num = i64::from_be_bytes(hash);
let results = image_query(
pool.clone(),
tree.clone(),
vec![num],
3,
Some(hash.to_vec()),
)
.await
.unwrap();
let resp = warp::http::Response::builder()
.header("x-image-hash", num.to_string())
.header("x-rate-limit-total-image", image_remaining.1.to_string())
.header(
"x-rate-limit-remaining-image",
image_remaining.0.to_string(),
)
.header("x-rate-limit-total-hash", hash_remaining.1.to_string())
.header("x-rate-limit-remaining-hash", hash_remaining.0.to_string())
.header("content-type", "application/json")
.body(serde_json::to_string(&results).unwrap())
.unwrap();
Ok(Box::new(resp))
}
#[tracing::instrument]
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> {
pub async fn handle_rejection(err: Rejection) -> Result<Box<dyn Reply>, std::convert::Infallible> {
warn!("had rejection");
let (code, message) = if err.is_not_found() {
@ -309,29 +459,10 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert
warp::http::StatusCode::NOT_FOUND,
"This page does not exist",
)
} else if let Some(err) = err.find::<Error>() {
match err {
Error::BB8(_inner) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
"A database error occured",
),
Error::Postgres(_inner) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
"A database error occured",
),
Error::InvalidData => (
warp::http::StatusCode::BAD_REQUEST,
"Unable to operate on provided data",
),
Error::ApiKey => (
warp::http::StatusCode::UNAUTHORIZED,
"Invalid API key provided",
),
Error::RateLimit => (
warp::http::StatusCode::TOO_MANY_REQUESTS,
"Your API token is rate limited",
),
}
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
return Ok(Box::new(Error::InvalidData) as Box<dyn Reply>);
} else {
(
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
@ -344,5 +475,5 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert
message: message.into(),
});
Ok(warp::reply::with_status(json, code))
Ok(Box::new(warp::reply::with_status(json, code)))
}

View File

@ -130,7 +130,7 @@ async fn main() {
let tree_clone = tree.clone();
let pool_clone = db_pool.clone();
tokio::spawn(async move {
use futures_util::StreamExt;
use futures::StreamExt;
let max_id = std::sync::atomic::AtomicI32::new(max_id);
let tree = tree_clone;
@ -138,7 +138,13 @@ async fn main() {
let order = std::sync::atomic::Ordering::SeqCst;
let interval = tokio::time::interval(std::time::Duration::from_secs(30));
let interval = async_stream::stream! {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
while let item = interval.tick().await {
yield item;
}
};
interval
.for_each(|_| async {
@ -148,7 +154,7 @@ async fn main() {
let mut lock = tree.write().await;
let id = max_id.load(order);
let mut count = 0;
let mut count: usize = 0;
conn.query("SELECT id, hash FROM hashes WHERE hashes.id > $1", &[&id])
.await

View File

@ -20,7 +20,7 @@ pub enum RateLimit {
/// This key is limited, we should deny the request.
Limited,
/// This key is available, contains the number of requests made.
Available(i16),
Available((i16, i16)),
}
/// A general type for every file.
@ -112,3 +112,8 @@ pub struct HashSearchOpts {
pub struct HandleOpts {
pub twitter: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct URLSearchOpts {
pub url: String,
}

View File

@ -7,18 +7,38 @@ macro_rules! rate_limit {
rate_limit!($api_key, $db, $limit, $group, 1)
};
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {
let api_key = crate::models::lookup_api_key($api_key, $db)
.await
.ok_or_else(|| warp::reject::custom(Error::ApiKey))?;
($api_key:expr, $db:expr, $limit:tt, $group:expr, $incr_by:expr) => {{
let api_key = match crate::models::lookup_api_key($api_key, $db).await {
Some(api_key) => api_key,
None => return Ok(Box::new(Error::ApiKey)),
};
let rate_limit =
crate::utils::update_rate_limit($db, api_key.id, api_key.$limit, $group, $incr_by)
let rate_limit = match crate::utils::update_rate_limit(
$db,
api_key.id,
api_key.$limit,
$group,
$incr_by,
)
.await
.map_err(crate::handlers::map_postgres_err)?;
{
Ok(rate_limit) => rate_limit,
Err(err) => return Ok(Box::new(Error::Postgres(err))),
};
if rate_limit == crate::types::RateLimit::Limited {
return Err(warp::reject::custom(Error::RateLimit));
match rate_limit {
crate::types::RateLimit::Limited => return Ok(Box::new(Error::RateLimit)),
crate::types::RateLimit::Available(count) => count,
}
}};
}
#[macro_export]
macro_rules! early_return {
($val:expr) => {
match $val {
Ok(val) => val,
Err(err) => return Ok(Box::new(Error::from(err))),
}
};
}
@ -59,14 +79,17 @@ pub async fn update_rate_limit(
if count > key_group_limit {
Ok(RateLimit::Limited)
} else {
Ok(RateLimit::Available(count))
Ok(RateLimit::Available((
key_group_limit - count,
key_group_limit,
)))
}
}
pub fn extract_rows<'a>(
pub fn extract_rows(
rows: Vec<tokio_postgres::Row>,
hash: Option<&'a [u8]>,
) -> impl IntoIterator<Item = File> + 'a {
hash: Option<&[u8]>,
) -> impl IntoIterator<Item = File> + '_ {
rows.into_iter().map(move |row| {
let dbhash: i64 = row.get("hash");
let dbbytes = dbhash.to_be_bytes();