Initial commit.

Syfaro 2021-07-27 22:25:53 -04:00
commit 9f13ce1939
8 changed files with 3273 additions and 0 deletions

.gitignore (vendored, new file, +1)

@@ -0,0 +1 @@
/target

.gitlab-ci.yml (new file, +37)

@@ -0,0 +1,37 @@
stages:
  - build
  - image

variables:
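  # Redirect CARGO_HOME into the project directory so the .cargo/ cache path
  # below can persist the crate registry between pipeline runs.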
  CARGO_HOME: "$CI_PROJECT_DIR/.cargo"
build:
  image: rust:1.53-slim-buster
  stage: build
  artifacts:
    expire_in: 1 day
    paths:
      - ./bkapi
  cache:
    - key:
        files:
          - Cargo.lock
      paths:
        - target/
        - .cargo/
  script:
    - cargo build --release --verbose
    - mv ./target/release/bkapi ./bkapi
docker:
  image:
    name: gcr.io/kaniko-project/executor:debug
    entrypoint: [""]
  stage: image
  needs:
    - build
  before_script:
    - mkdir -p /kaniko/.docker
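    # Write a Docker config.json with base64-encoded registry credentials so
    # kaniko, which builds without a Docker daemon, can push to $CI_REGISTRY.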
    - echo "{\"auths\":{\"$CI_REGISTRY\":{\"auth\":\"$(echo -n ${CI_REGISTRY_USER}:${CI_REGISTRY_PASSWORD} | base64)\"}}}" > /kaniko/.docker/config.json
  script:
    - /kaniko/executor --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA --destination $CI_REGISTRY_IMAGE:latest --cache=true

Cargo.lock (generated, new file, +2717)

Diff suppressed because it is too large.

Cargo.toml (new file, +37)

@@ -0,0 +1,37 @@
[package]
name = "bkapi"
version = "0.1.0"
authors = ["Syfaro <syfaro@huefox.com>"]
edition = "2018"

[dependencies]
envconfig = "0.10"
thiserror = "1"
tracing = "0.1"
tracing-subscriber = "0.2"
tracing-unwrap = "0.9"
tracing-opentelemetry = "0.14"
opentelemetry = { version = "0.15", features = ["rt-async-std"] }
opentelemetry-semantic-conventions = "0.7.0"
opentelemetry-jaeger = "0.14"
lazy_static = "1"
prometheus = { version = "0.12", features = ["process"] }
bk-tree = "0.4.0"
hamming = "0.1"
futures = "0.3"
async-std = { version = "1", features = ["attributes"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tide = "0.16"
tide-tracing = "0.0.11"
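
# sqlx's runtime feature must match the executor the service runs on; the
# async-std runtime is selected here to line up with tide and the async-std
# tasks in src/main.rs.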
[dependencies.sqlx]
version = "0.5"
features = ["runtime-async-std-native-tls", "postgres"]

Dockerfile (new file, +4)

@@ -0,0 +1,4 @@
FROM debian:buster-slim
EXPOSE 3000
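# ./bkapi is the release binary produced by the CI build job's artifacts (see
# .gitlab-ci.yml); it is not compiled inside this image.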
COPY ./bkapi /bin/bkapi
CMD ["/bin/bkapi"]

README.md (new file, +10)

@@ -0,0 +1,10 @@
# bkapi

A fast way to look up hashes by Hamming distance.

It operates by connecting to a PostgreSQL database (`DATABASE_URL`), loading
the hash from the first column of every row returned by a provided query
(`DATABASE_QUERY`), subscribing to update events (`DATABASE_SUBSCRIBE`), and
holding everything in a BK-tree.

It provides a single API endpoint, `/search`, which takes `hash` and
`distance` query parameters and returns all hashes within that distance.
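
For example (illustrative values; the response shape follows the
`SearchResponse` struct in `src/main.rs`):

```
GET /search?hash=12345&distance=3

{"hash":12345,"distance":3,"hashes":[{"hash":12349,"distance":2}]}
```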

src/main.rs (new file, +328)

@@ -0,0 +1,328 @@
use std::sync::Arc;

use async_std::sync::{RwLock, RwLockUpgradableReadGuard};
use envconfig::Envconfig;
use opentelemetry::KeyValue;
use sqlx::{
    postgres::{PgListener, PgPoolOptions},
    Pool, Postgres, Row,
};
use tide::Request;
use tracing_subscriber::layer::SubscriberExt;
use tracing_unwrap::ResultExt;

mod middlewares;

lazy_static::lazy_static! {
    static ref HTTP_REQUEST_COUNT: prometheus::CounterVec = prometheus::register_counter_vec!("http_requests_total", "Number of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
    static ref HTTP_REQUEST_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("http_request_duration_seconds", "Duration of HTTP requests", &["http_route", "http_method", "http_status_code"]).unwrap();
    static ref TREE_DURATION: prometheus::HistogramVec = prometheus::register_histogram_vec!("bkapi_tree_duration_seconds", "Duration of tree search time", &["distance"]).unwrap();
}

#[derive(thiserror::Error, Debug)]
enum Error {
    #[error("row was unable to be loaded: {0}")]
    LoadingRow(sqlx::Error),
    #[error("listener could not listen: {0}")]
    Listener(sqlx::Error),
    #[error("listener got data that could not be decoded: {0}")]
    Data(serde_json::Error),
}

type Tree = Arc<RwLock<bk_tree::BKTree<Node, Hamming>>>;
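
// Config is populated from the environment; envconfig maps each bare field to
// the upper-cased variable of the same name (HTTP_LISTEN, DATABASE_URL, ...).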
#[derive(Envconfig, Clone)]
struct Config {
    #[envconfig(default = "0.0.0.0:3000")]
    http_listen: String,
    #[envconfig(default = "127.0.0.1:6831")]
    jaeger_agent: String,
    #[envconfig(default = "bkapi")]
    service_name: String,

    database_url: String,
    database_query: String,
    database_subscribe: String,
    #[envconfig(default = "false")]
    database_is_unique: bool,

    max_distance: Option<u32>,
}

/// A hamming distance metric.
struct Hamming;

impl bk_tree::Metric<Node> for Hamming {
    fn distance(&self, a: &Node, b: &Node) -> u32 {
        hamming::distance_fast(&a.0, &b.0).expect_or_log("hashes did not have same byte alignment")
            as u32
    }

    fn threshold_distance(&self, a: &Node, b: &Node, _threshold: u32) -> Option<u32> {
        Some(self.distance(a, b))
    }
}
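
// Hashes arrive as i64s but are compared as raw bytes; the big-endian
// conversions below let the two representations round-trip exactly.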
/// A value of a node in the BK tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Node([u8; 8]);

impl From<i64> for Node {
    fn from(num: i64) -> Self {
        Self(num.to_be_bytes())
    }
}

impl From<Node> for i64 {
    fn from(node: Node) -> Self {
        i64::from_be_bytes(node.0)
    }
}

#[async_std::main]
async fn main() {
    let config = Config::init_from_env().expect("could not load config");

    opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());

    let env = std::env::var("ENVIRONMENT");
    let env = if let Ok(env) = env.as_ref() {
        env.as_str()
    } else if cfg!(debug_assertions) {
        "debug"
    } else {
        "release"
    };

    let tracer = opentelemetry_jaeger::new_pipeline()
        .with_agent_endpoint(&config.jaeger_agent)
        .with_service_name(&config.service_name)
        .with_tags(vec![
            KeyValue::new("environment", env.to_owned()),
            KeyValue::new("version", env!("CARGO_PKG_VERSION")),
        ])
        .install_batch(opentelemetry::runtime::AsyncStd)
        .expect("otel jaeger pipeline could not be created");
    let trace = tracing_opentelemetry::layer().with_tracer(tracer.clone());

    tracing::subscriber::set_global_default(
        tracing_subscriber::Registry::default()
            .with(tracing_subscriber::EnvFilter::from_default_env())
            .with(trace)
            .with(tracing_subscriber::fmt::layer()),
    )
    .expect("tracing could not be configured");

    tracing::info!("starting bkapi");
    tracing::debug!("loaded config");

    let tree: Tree = Arc::new(RwLock::new(bk_tree::BKTree::new(Hamming)));

    let pool = PgPoolOptions::new()
        .max_connections(2)
        .connect(&config.database_url)
        .await
        .expect_or_log("could not connect to database");
    tracing::debug!("connected to postgres");

    let http_listen = config.http_listen.clone();
    let max_distance = config.max_distance;

    let (sender, receiver) = futures::channel::oneshot::channel();

    tracing::info!("starting to listen for payloads");
    let tree_clone = tree.clone();
    async_std::task::spawn(async {
        listen_for_payloads(pool, config, tree_clone, sender)
            .await
            .expect_or_log("listening for updates failed");
    });

    tracing::info!("waiting for initial tree to load");
    receiver
        .await
        .expect_or_log("tree loading was dropped before completing");
    tracing::info!("initial tree loaded, starting server");

    let mut app = tide::with_state(State { tree, max_distance });
    app.with(middlewares::TideOpentelemMiddleware::new(tracer));
    app.with(tide_tracing::TraceMiddleware::new());
    app.with(middlewares::TidePrometheusMiddleware);
    app.at("/search").get(search);

    app.listen(&http_listen)
        .await
        .expect_or_log("could not start web server");
}
#[derive(Clone)]
struct State {
    tree: Tree,
    max_distance: Option<u32>,
}

#[derive(serde::Deserialize)]
struct Query {
    hash: i64,
    distance: u32,
}

#[derive(serde::Serialize)]
struct HashDistance {
    hash: i64,
    distance: u32,
}

#[derive(serde::Serialize)]
struct SearchResponse {
    hash: i64,
    distance: u32,
    hashes: Vec<HashDistance>,
}

#[tracing::instrument(skip(req))]
async fn search(req: Request<State>) -> tide::Result {
    let state = req.state();

    let Query { hash, distance } = req.query()?;
    tracing::info!("searching for hash {} with distance {}", hash, distance);

    if matches!(state.max_distance, Some(max_distance) if distance > max_distance) {
        return Err(tide::Error::from_str(
            400,
            "Distance is greater than max distance",
        ));
    }
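
    // Only a read lock is needed to search, so lookups can run concurrently;
    // they are briefly blocked while the listener task inserts a new hash.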
    let tree = state.tree.read().await;

    let duration = TREE_DURATION
        .with_label_values(&[&distance.to_string()])
        .start_timer();
    let matches: Vec<HashDistance> = tree
        .find(&hash.into(), distance)
        .into_iter()
        .map(|item| HashDistance {
            distance: item.0,
            hash: (*item.1).into(),
        })
        .collect();
    let time = duration.stop_and_record();
    tracing::debug!("found {} items in {} seconds", matches.len(), time);

    let resp = SearchResponse {
        hash,
        distance,
        hashes: matches,
    };

    Ok(serde_json::to_string(&resp)?.into())
}

/// Create a new BK tree and pull in all hashes from the provided query.
///
/// This must be called after you have started a listener, otherwise items may
/// be lost.
async fn create_tree(
    conn: &Pool<Postgres>,
    config: &Config,
) -> Result<bk_tree::BKTree<Node, Hamming>, Error> {
    use futures::TryStreamExt;

    tracing::warn!("creating new tree");

    let mut tree = bk_tree::BKTree::new(Hamming);
    let mut rows = sqlx::query(&config.database_query).fetch(conn);

    let mut count = 0;
    let start = std::time::Instant::now();

    while let Some(row) = rows.try_next().await.map_err(Error::LoadingRow)? {
        let node: Node = row.get::<i64, _>(0).into();

        // Avoid checking if each value is unique if we were told that the
        // database query only returns unique values.
        if config.database_is_unique || tree.find_exact(&node).is_none() {
            tree.add(node);
        }

        count += 1;
        if count % 250_000 == 0 {
            tracing::debug!(count, "loaded more rows");
        }
    }

    let dur = std::time::Instant::now().duration_since(start);
    tracing::info!(count, "completed loading rows in {:?}", dur);

    Ok(tree)
}
#[derive(serde::Deserialize)]
struct Payload {
    hash: i64,
}

/// Listen for incoming payloads.
///
/// This will create a new tree to ensure all items are present. It will also
/// automatically recreate trees as needed if the database connection is lost.
async fn listen_for_payloads(
    conn: Pool<Postgres>,
    config: Config,
    tree: Tree,
    initial: futures::channel::oneshot::Sender<()>,
) -> Result<(), Error> {
    let mut listener = PgListener::connect_with(&conn)
        .await
        .map_err(Error::Listener)?;
    listener
        .listen(&config.database_subscribe)
        .await
        .map_err(Error::Listener)?;

    let new_tree = create_tree(&conn, &config).await?;
    {
        let mut tree = tree.write().await;
        *tree = new_tree;
    }

    initial
        .send(())
        .expect_or_log("nothing listening for initial data");

    loop {
        while let Some(notification) = listener.try_recv().await.map_err(Error::Listener)? {
            let payload: Payload =
                serde_json::from_str(notification.payload()).map_err(Error::Data)?;
            tracing::debug!(hash = payload.hash, "evaluating new payload");

            let node: Node = payload.hash.into();
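
            // An upgradable read lets searches continue while we check
            // whether the hash already exists; readers only block if we
            // upgrade the lock to actually insert.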
            let tree = tree.upgradable_read().await;
            if tree.find_exact(&node).is_some() {
                tracing::trace!("hash already existed in tree");
                continue;
            }

            tracing::trace!("hash did not exist, adding to tree");
            let mut tree = RwLockUpgradableReadGuard::upgrade(tree).await;
            tree.add(node);
        }

        tracing::error!("disconnected from listener, recreating tree");
        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        let new_tree = create_tree(&conn, &config).await?;
        {
            let mut tree = tree.write().await;
            *tree = new_tree;
        }
    }
}

src/middlewares.rs (new file, +139)

@@ -0,0 +1,139 @@
use std::collections::HashMap;
use std::convert::TryFrom;

use opentelemetry::{
    global::get_text_map_propagator,
    trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer},
    Context,
};
use opentelemetry_semantic_conventions::trace;
use prometheus::{Encoder, TextEncoder};
use tide::{
    http::{
        headers::{HeaderName, HeaderValue},
        mime,
    },
    Middleware, Request, Response,
};

pub struct TidePrometheusMiddleware;

impl TidePrometheusMiddleware {
    const ROUTE: &'static str = "/metrics";
}
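
// Serves the Prometheus text exposition format at /metrics and records a
// request count and duration histogram for every other route.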
#[tide::utils::async_trait]
impl<State: Clone + Send + Sync + 'static> Middleware<State> for TidePrometheusMiddleware {
    async fn handle(&self, req: Request<State>, next: tide::Next<'_, State>) -> tide::Result {
        let path = req.url().path().to_owned();
        if path == Self::ROUTE {
            let mut buffer = Vec::new();
            let encoder = TextEncoder::new();
            let metric_families = prometheus::gather();
            encoder.encode(&metric_families, &mut buffer).unwrap();

            return Ok(Response::builder(200)
                .body(buffer)
                .content_type(mime::PLAIN)
                .build());
        }

        let method = req.method().to_string();

        let start = std::time::Instant::now();
        let res = next.run(req).await;
        let end = std::time::Instant::now().duration_since(start);

        let status_code = res.status().to_string();
        let labels: Vec<&str> = vec![&path, &method, &status_code];
        crate::HTTP_REQUEST_COUNT.with_label_values(&labels).inc();
        crate::HTTP_REQUEST_DURATION
            .with_label_values(&labels)
            .observe(end.as_secs_f64());

        Ok(res)
    }
}
pub struct TideOpentelemMiddleware<T: Tracer> {
    tracer: T,
}

impl<T: Tracer> TideOpentelemMiddleware<T> {
    pub fn new(tracer: T) -> Self {
        Self { tracer }
    }
}

#[tide::utils::async_trait]
impl<T: Tracer + Send + Sync, State: Clone + Send + Sync + 'static> Middleware<State>
    for TideOpentelemMiddleware<T>
{
    async fn handle(&self, req: Request<State>, next: tide::Next<'_, State>) -> tide::Result {
        let parent_cx = get_parent_cx(&req);

        let method = req.method().to_string();
        let url = req.url();

        let attributes = vec![
            trace::HTTP_METHOD.string(method.clone()),
            trace::HTTP_SCHEME.string(url.scheme().to_string()),
            trace::HTTP_URL.string(url.to_string()),
        ];

        let mut span_builder = self
            .tracer
            .span_builder(format!("{} {}", method, url.path()))
            .with_kind(SpanKind::Server)
            .with_attributes(attributes);

        if parent_cx.span().span_context().is_remote() {
            tracing::trace!("incoming request has remote span: {:?}", parent_cx);
            span_builder = span_builder.with_parent_context(parent_cx);
        }

        let mut span = span_builder.start(&self.tracer);
        span.add_event("request.started".to_owned(), vec![]);

        let cx = &Context::current_with_span(span);
        let mut res = next.run(req).with_context(cx.clone()).await;

        let span = cx.span();
        span.add_event("request.completed".to_owned(), vec![]);
        span.set_attribute(trace::HTTP_STATUS_CODE.i64(u16::from(res.status()).into()));

        if let Some(len) = res.len().and_then(|len| i64::try_from(len).ok()) {
            span.set_attribute(trace::HTTP_RESPONSE_CONTENT_LENGTH.i64(len));
        }

        let mut injector = HashMap::new();
        get_text_map_propagator(|propagator| propagator.inject_context(&cx, &mut injector));

        for (key, value) in injector {
            let header_name = HeaderName::from_bytes(key.into_bytes());
            let header_value = HeaderValue::from_bytes(value.into_bytes());

            if let (Ok(name), Ok(value)) = (header_name, header_value) {
                res.insert_header(name, value);
            } else {
                tracing::error!("injected header data was invalid");
            }
        }

        Ok(res)
    }
}
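
// Rebuild the parent trace context from the incoming request's headers using
// the globally registered propagator (the Jaeger propagator set in main).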
fn get_parent_cx<State>(req: &Request<State>) -> Context {
    let mut req_headers = HashMap::new();
    for (key, value) in req.iter() {
        req_headers.insert(key.to_string(), value.last().to_string());
    }

    get_text_map_propagator(|propagator| propagator.extract(&req_headers))
}