mod endpoint;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;
use super::service::{Connection, DynamicServiceStream, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Request, Response,
};
use hyper::client::connect::Connection as HyperConnection;
use std::{
fmt,
future::Future,
hash::Hash,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{channel, Sender},
};
use tower::balance::p2c::Balance;
use tower::{
buffer::{self, Buffer},
discover::{Change, Discover},
util::{BoxService, Either},
Service,
};
// The service a `Channel` buffers: `Either::A` holds a single `Connection`, while
// `Either::B` holds a type-erased boxed service (used for the load-balanced case).
type Svc = Either<Connection, BoxService<Request<BoxBody>, Response<hyper::Body>, crate::Error>>;
// Default bound used when an `Endpoint` does not specify `buffer_size`; the
// load-balanced constructors in this file also use it.
const DEFAULT_BUFFER_SIZE: usize = 1024;
/// A cloneable handle through which HTTP requests are sent to the underlying service.
///
/// Requests go through a tower [`Buffer`], whose worker future is handed to an executor
/// when the channel is constructed.
#[derive(Clone)]
pub struct Channel {
// Buffered front for either a single connection or a boxed (balanced) service.
svc: Buffer<Svc, Request<BoxBody>>,
}
/// The future returned by [`Channel`]'s `Service::call`.
///
/// It wraps the buffer's response future; its `Future` impl converts the buffer's
/// error into this module's error type.
pub struct ResponseFuture {
// Response future produced by the `Buffer` wrapping the inner service.
inner: buffer::future::ResponseFuture<<Svc as Service<Request<BoxBody>>>::Future>,
}
impl Channel {
/// Starts building a channel by converting `uri` into an [`Endpoint`].
pub fn builder(uri: Uri) -> Endpoint {
Endpoint::from(uri)
}
/// Builds an [`Endpoint`] from a `'static` string slice.
///
/// The string is parsed with `Uri::from_static` and the result is passed to
/// [`Channel::builder`].
///
/// # Panics
///
/// `Uri::from_static` panics when the string is not a valid URI (per the `http`
/// crate documentation).
pub fn from_static(s: &'static str) -> Endpoint {
    Self::builder(Uri::from_static(s))
}
/// Builds an [`Endpoint`] from anything convertible into [`Bytes`].
///
/// # Errors
///
/// Returns the [`InvalidUri`] reported by `Uri::from_maybe_shared` when the bytes
/// cannot be parsed as a URI.
pub fn from_shared(s: impl Into<Bytes>) -> Result<Endpoint, InvalidUri> {
    Uri::from_maybe_shared(s.into()).map(Self::builder)
}
/// Creates a load-balanced `Channel` over a fixed list of endpoints.
///
/// Each endpoint is announced to the balancer as a `Change::Insert` keyed by a
/// clone of its own URI. The sending half of the discovery channel is dropped
/// when this function returns, so the set of endpoints cannot change afterwards.
///
/// # Panics
///
/// Panics if the receiving half of the discovery channel has already been
/// dropped while the endpoints are being queued.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
    let endpoints: Vec<Endpoint> = list.collect();

    // `try_send` never waits for free space, so the channel must be able to hold
    // every endpoint up front. A fixed capacity smaller than the list would make
    // `try_send` report `Full` unless the worker happened to drain it in time.
    let capacity = endpoints.len().max(DEFAULT_BUFFER_SIZE);
    let (channel, tx) = Self::balance_channel(capacity);

    for endpoint in endpoints {
        tx.try_send(Change::Insert(endpoint.uri.clone(), endpoint))
            .expect("discovery channel closed before all endpoints were queued");
    }

    channel
}
/// Creates a load-balanced `Channel` together with the sender used to add or
/// remove endpoints, using the executor returned by `SharedExec::tokio()`.
///
/// `capacity` is forwarded unchanged to
/// [`Channel::balance_channel_with_executor`].
pub fn balance_channel<K>(capacity: usize) -> (Self, Sender<Change<K, Endpoint>>)
where
    K: Hash + Eq + Send + Clone + 'static,
{
    let executor = SharedExec::tokio();
    Self::balance_channel_with_executor(capacity, executor)
}
/// Creates a load-balanced `Channel` whose buffer worker runs on `executor`,
/// returning it with the sender used to publish endpoint changes.
///
/// `capacity` bounds the channel that carries the [`Change`] events. The request
/// buffer itself is always sized with `DEFAULT_BUFFER_SIZE`.
pub fn balance_channel_with_executor<K, E>(
    capacity: usize,
    executor: E,
) -> (Self, Sender<Change<K, Endpoint>>)
where
    K: Hash + Eq + Send + Clone + 'static,
    E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
{
    let (sender, receiver) = channel(capacity);
    // The receiving half becomes the discovery stream feeding the balancer.
    let discover = DynamicServiceStream::new(receiver);
    let balanced = Self::balance(discover, DEFAULT_BUFFER_SIZE, executor);
    (balanced, sender)
}
pub(crate) fn new<C>(connector: C, endpoint: Endpoint) -> Self
where
C: Service<Uri> + Send + 'static,
C::Error: Into<crate::Error> + Send,
C::Future: Unpin + Send,
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
let executor = endpoint.executor.clone();
let svc = Connection::lazy(connector, endpoint);
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
executor.execute(Box::pin(worker));
Channel { svc }
}
pub(crate) async fn connect<C>(connector: C, endpoint: Endpoint) -> Result<Self, super::Error>
where
C: Service<Uri> + Send + 'static,
C::Error: Into<crate::Error> + Send,
C::Future: Unpin + Send,
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
let executor = endpoint.executor.clone();
let svc = Connection::connect(connector, endpoint)
.await
.map_err(super::Error::from_source)?;
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
executor.execute(Box::pin(worker));
Ok(Channel { svc })
}
pub(crate) fn balance<D, E>(discover: D, buffer_size: usize, executor: E) -> Self
where
D: Discover<Service = Connection> + Unpin + Send + 'static,
D::Error: Into<crate::Error>,
D::Key: Hash + Send + Clone,
E: Executor<futures_core::future::BoxFuture<'static, ()>> + Send + Sync + 'static,
{
let svc = Balance::new(discover);
let svc = BoxService::new(svc);
let (svc, worker) = Buffer::pair(Either::B(svc), buffer_size);
executor.execute(Box::pin(worker));
Channel { svc }
}
}
impl Service<http::Request<BoxBody>> for Channel {
type Response = http::Response<super::Body>;
type Error = super::Error;
type Future = ResponseFuture;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Service::poll_ready(&mut self.svc, cx).map_err(super::Error::from_source)
}
fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
let inner = Service::call(&mut self.svc, request);
ResponseFuture { inner }
}
}
impl Future for ResponseFuture {
type Output = Result<Response<hyper::Body>, super::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let val = futures_util::ready!(Pin::new(&mut self.inner).poll(cx))
.map_err(super::Error::from_source)?;
Ok(val).into()
}
}
/// Prints only the type name; the inner service is not shown.
impl fmt::Debug for Channel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as a `debug_struct` with no fields: just the name.
        f.write_str("Channel")
    }
}
/// Prints only the type name; the inner future is not shown.
impl fmt::Debug for ResponseFuture {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as a `debug_struct` with no fields: just the name.
        f.write_str("ResponseFuture")
    }
}