add rust lang to repo by add assets/mini-redis

This commit is contained in:
sunface
2022-02-24 16:14:58 +08:00
parent d0c4a4669a
commit 58cb728626
31 changed files with 5036 additions and 0 deletions

View File

@ -0,0 +1,108 @@
use mini_redis::{client, DEFAULT_PORT};
use bytes::Bytes;
use std::num::ParseIntError;
use std::str;
use std::time::Duration;
use structopt::StructOpt;
#[derive(StructOpt, Debug)]
#[structopt(name = "mini-redis-cli", author = env!("CARGO_PKG_AUTHORS"), about = "Issue Redis commands")]
struct Cli {
#[structopt(subcommand)]
command: Command,
#[structopt(name = "hostname", long = "--host", default_value = "127.0.0.1")]
host: String,
#[structopt(name = "port", long = "--port", default_value = DEFAULT_PORT)]
port: String,
}
#[derive(StructOpt, Debug)]
enum Command {
/// Get the value of key.
Get {
/// Name of key to get
key: String,
},
/// Set key to hold the string value.
Set {
/// Name of key to set
key: String,
/// Value to set.
#[structopt(parse(from_str = bytes_from_str))]
value: Bytes,
/// Expire the value after specified amount of time
#[structopt(parse(try_from_str = duration_from_ms_str))]
expires: Option<Duration>,
},
}
/// Entry point for CLI tool.
///
/// The `[tokio::main]` annotation signals that the Tokio runtime should be
/// started when the function is called. The body of the function is executed
/// within the newly spawned runtime.
///
/// `flavor = "current_thread"` is used here to avoid spawning background
/// threads. The CLI tool use case benefits more by being lighter instead of
/// multi-threaded.
#[tokio::main(flavor = "current_thread")]
async fn main() -> mini_redis::Result<()> {
// Enable logging
tracing_subscriber::fmt::try_init()?;
// Parse command line arguments
let cli = Cli::from_args();
// Get the remote address to connect to
let addr = format!("{}:{}", cli.host, cli.port);
// Establish a connection
let mut client = client::connect(&addr).await?;
// Process the requested command
match cli.command {
Command::Get { key } => {
if let Some(value) = client.get(&key).await? {
if let Ok(string) = str::from_utf8(&value) {
println!("\"{}\"", string);
} else {
println!("{:?}", value);
}
} else {
println!("(nil)");
}
}
Command::Set {
key,
value,
expires: None,
} => {
client.set(&key, value).await?;
println!("OK");
}
Command::Set {
key,
value,
expires: Some(expires),
} => {
client.set_expires(&key, value, expires).await?;
println!("OK");
}
}
Ok(())
}
fn duration_from_ms_str(src: &str) -> Result<Duration, ParseIntError> {
let ms = src.parse::<u64>()?;
Ok(Duration::from_millis(ms))
}
fn bytes_from_str(src: &str) -> Bytes {
Bytes::from(src.to_string())
}

View File

@ -0,0 +1,37 @@
//! mini-redis server.
//!
//! This file is the entry point for the server implemented in the library. It
//! performs command line parsing and passes the arguments on to
//! `mini_redis::server`.
//!
//! The `clap` crate is used for parsing arguments.
use mini_redis::{server, DEFAULT_PORT};
use structopt::StructOpt;
use tokio::net::TcpListener;
use tokio::signal;
#[tokio::main]
pub async fn main() -> mini_redis::Result<()> {
// enable logging
// see https://docs.rs/tracing for more info
tracing_subscriber::fmt::try_init()?;
let cli = Cli::from_args();
let port = cli.port.as_deref().unwrap_or(DEFAULT_PORT);
// Bind a TCP listener
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
server::run(listener, signal::ctrl_c()).await;
Ok(())
}
#[derive(StructOpt, Debug)]
#[structopt(name = "mini-redis-server", version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), about = "A Redis server")]
struct Cli {
#[structopt(name = "port", long = "--port")]
port: Option<String>,
}

View File

@ -0,0 +1,264 @@
//! Minimal blocking Redis client implementation
//!
//! Provides a blocking connect and methods for issuing the supported commands.
use bytes::Bytes;
use std::time::Duration;
use tokio::net::ToSocketAddrs;
use tokio::runtime::Runtime;
pub use crate::client::Message;
/// Established connection with a Redis server.
///
/// Backed by a single `TcpStream`, `BlockingClient` provides basic network
/// client functionality (no pooling, retrying, ...). Connections are
/// established using the [`connect`](fn@connect) function.
///
/// Requests are issued using the various methods of `Client`.
pub struct BlockingClient {
/// The asynchronous `Client`.
inner: crate::client::Client,
/// A `current_thread` runtime for executing operations on the asynchronous
/// client in a blocking manner.
rt: Runtime,
}
/// A client that has entered pub/sub mode.
///
/// Once clients subscribe to a channel, they may only perform pub/sub related
/// commands. The `BlockingClient` type is transitioned to a
/// `BlockingSubscriber` type in order to prevent non-pub/sub methods from being
/// called.
pub struct BlockingSubscriber {
/// The asynchronous `Subscriber`.
inner: crate::client::Subscriber,
/// A `current_thread` runtime for executing operations on the asynchronous
/// `Subscriber` in a blocking manner.
rt: Runtime,
}
/// The iterator returned by `Subscriber::into_iter`.
struct SubscriberIterator {
/// The asynchronous `Subscriber`.
inner: crate::client::Subscriber,
/// A `current_thread` runtime for executing operations on the asynchronous
/// `Subscriber` in a blocking manner.
rt: Runtime,
}
/// Establish a connection with the Redis server located at `addr`.
///
/// `addr` may be any type that can be asynchronously converted to a
/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs`
/// trait is the Tokio version and not the `std` version.
///
/// # Examples
///
/// ```no_run
/// use mini_redis::blocking_client;
///
/// fn main() {
/// let client = match blocking_client::connect("localhost:6379") {
/// Ok(client) => client,
/// Err(_) => panic!("failed to establish connection"),
/// };
/// # drop(client);
/// }
/// ```
pub fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<BlockingClient> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let inner = rt.block_on(crate::client::connect(addr))?;
Ok(BlockingClient { inner, rt })
}
impl BlockingClient {
/// Get the value of key.
///
/// If the key does not exist the special value `None` is returned.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::blocking_client;
///
/// fn main() {
/// let mut client = blocking_client::connect("localhost:6379").unwrap();
///
/// let val = client.get("foo").unwrap();
/// println!("Got = {:?}", val);
/// }
/// ```
pub fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> {
self.rt.block_on(self.inner.get(key))
}
/// Set `key` to hold the given `value`.
///
/// The `value` is associated with `key` until it is overwritten by the next
/// call to `set` or it is removed.
///
/// If key already holds a value, it is overwritten. Any previous time to
/// live associated with the key is discarded on successful SET operation.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::blocking_client;
///
/// fn main() {
/// let mut client = blocking_client::connect("localhost:6379").unwrap();
///
/// client.set("foo", "bar".into()).unwrap();
///
/// // Getting the value immediately works
/// let val = client.get("foo").unwrap().unwrap();
/// assert_eq!(val, "bar");
/// }
/// ```
pub fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
self.rt.block_on(self.inner.set(key, value))
}
/// Set `key` to hold the given `value`. The value expires after `expiration`
///
/// The `value` is associated with `key` until one of the following:
/// - it expires.
/// - it is overwritten by the next call to `set`.
/// - it is removed.
///
/// If key already holds a value, it is overwritten. Any previous time to
/// live associated with the key is discarded on a successful SET operation.
///
/// # Examples
///
/// Demonstrates basic usage. This example is not **guaranteed** to always
/// work as it relies on time based logic and assumes the client and server
/// stay relatively synchronized in time. The real world tends to not be so
/// favorable.
///
/// ```no_run
/// use mini_redis::blocking_client;
/// use std::thread;
/// use std::time::Duration;
///
/// fn main() {
/// let ttl = Duration::from_millis(500);
/// let mut client = blocking_client::connect("localhost:6379").unwrap();
///
/// client.set_expires("foo", "bar".into(), ttl).unwrap();
///
/// // Getting the value immediately works
/// let val = client.get("foo").unwrap().unwrap();
/// assert_eq!(val, "bar");
///
/// // Wait for the TTL to expire
/// thread::sleep(ttl);
///
/// let val = client.get("foo").unwrap();
/// assert!(val.is_some());
/// }
/// ```
pub fn set_expires(
&mut self,
key: &str,
value: Bytes,
expiration: Duration,
) -> crate::Result<()> {
self.rt
.block_on(self.inner.set_expires(key, value, expiration))
}
/// Posts `message` to the given `channel`.
///
/// Returns the number of subscribers currently listening on the channel.
/// There is no guarantee that these subscribers receive the message as they
/// may disconnect at any time.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::blocking_client;
///
/// fn main() {
/// let mut client = blocking_client::connect("localhost:6379").unwrap();
///
/// let val = client.publish("foo", "bar".into()).unwrap();
/// println!("Got = {:?}", val);
/// }
/// ```
pub fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> {
self.rt.block_on(self.inner.publish(channel, message))
}
/// Subscribes the client to the specified channels.
///
/// Once a client issues a subscribe command, it may no longer issue any
/// non-pub/sub commands. The function consumes `self` and returns a
/// `BlockingSubscriber`.
///
/// The `BlockingSubscriber` value is used to receive messages as well as
/// manage the list of channels the client is subscribed to.
pub fn subscribe(self, channels: Vec<String>) -> crate::Result<BlockingSubscriber> {
let subscriber = self.rt.block_on(self.inner.subscribe(channels))?;
Ok(BlockingSubscriber {
inner: subscriber,
rt: self.rt,
})
}
}
impl BlockingSubscriber {
/// Returns the set of channels currently subscribed to.
pub fn get_subscribed(&self) -> &[String] {
self.inner.get_subscribed()
}
/// Receive the next message published on a subscribed channel, waiting if
/// necessary.
///
/// `None` indicates the subscription has been terminated.
pub fn next_message(&mut self) -> crate::Result<Option<Message>> {
self.rt.block_on(self.inner.next_message())
}
/// Convert the subscriber into an `Iterator` yielding new messages published
/// on subscribed channels.
pub fn into_iter(self) -> impl Iterator<Item = crate::Result<Message>> {
SubscriberIterator {
inner: self.inner,
rt: self.rt,
}
}
/// Subscribe to a list of new channels
pub fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> {
self.rt.block_on(self.inner.subscribe(channels))
}
/// Unsubscribe to a list of new channels
pub fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> {
self.rt.block_on(self.inner.unsubscribe(channels))
}
}
impl Iterator for SubscriberIterator {
type Item = crate::Result<Message>;
fn next(&mut self) -> Option<crate::Result<Message>> {
self.rt.block_on(self.inner.next_message()).transpose()
}
}

View File

@ -0,0 +1,120 @@
use crate::client::Client;
use crate::Result;
use bytes::Bytes;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
/// Create a new client request buffer
///
/// The `Client` performs Redis commands directly on the TCP connection. Only a
/// single request may be in-flight at a given time and operations require
/// mutable access to the `Client` handle. This prevents using a single Redis
/// connection from multiple Tokio tasks.
///
/// The strategy for dealing with this class of problem is to spawn a dedicated
/// Tokio task to manage the Redis connection and using "message passing" to
/// operate on the connection. Commands are pushed into a channel. The
/// connection task pops commands off of the channel and applies them to the
/// Redis connection. When the response is received, it is forwarded to the
/// original requester.
///
/// The returned `Buffer` handle may be cloned before passing the new handle to
/// separate tasks.
pub fn buffer(client: Client) -> Buffer {
// Setting the message limit to a hard coded value of 32. in a real-app, the
// buffer size should be configurable, but we don't need to do that here.
let (tx, rx) = channel(32);
// Spawn a task to process requests for the connection.
tokio::spawn(async move { run(client, rx).await });
// Return the `Buffer` handle.
Buffer { tx }
}
// Enum used to message pass the requested command from the `Buffer` handle
#[derive(Debug)]
enum Command {
Get(String),
Set(String, Bytes),
}
// Message type sent over the channel to the connection task.
//
// `Command` is the command to forward to the connection.
//
// `oneshot::Sender` is a channel type that sends a **single** value. It is used
// here to send the response received from the connection back to the original
// requester.
type Message = (Command, oneshot::Sender<Result<Option<Bytes>>>);
/// Receive commands sent through the channel and forward them to client. The
/// response is returned back to the caller via a `oneshot`.
async fn run(mut client: Client, mut rx: Receiver<Message>) {
// Repeatedly pop messages from the channel. A return value of `None`
// indicates that all `Buffer` handles have dropped and there will never be
// another message sent on the channel.
while let Some((cmd, tx)) = rx.recv().await {
// The command is forwarded to the connection
let response = match cmd {
Command::Get(key) => client.get(&key).await,
Command::Set(key, value) => client.set(&key, value).await.map(|_| None),
};
// Send the response back to the caller.
//
// Failing to send the message indicates the `rx` half dropped
// before receiving the message. This is a normal runtime event.
let _ = tx.send(response);
}
}
#[derive(Clone)]
pub struct Buffer {
tx: Sender<Message>,
}
impl Buffer {
/// Get the value of a key.
///
/// Same as `Client::get` but requests are **buffered** until the associated
/// connection has the ability to send the request.
pub async fn get(&mut self, key: &str) -> Result<Option<Bytes>> {
// Initialize a new `Get` command to send via the channel.
let get = Command::Get(key.into());
// Initialize a new oneshot to be used to receive the response back from the connection.
let (tx, rx) = oneshot::channel();
// Send the request
self.tx.send((get, tx)).await?;
// Await the response
match rx.await {
Ok(res) => res,
Err(err) => Err(err.into()),
}
}
/// Set `key` to hold the given `value`.
///
/// Same as `Client::set` but requests are **buffered** until the associated
/// connection has the ability to send the request
pub async fn set(&mut self, key: &str, value: Bytes) -> Result<()> {
// Initialize a new `Set` command to send via the channel.
let set = Command::Set(key.into(), value);
// Initialize a new oneshot to be used to receive the response back from the connection.
let (tx, rx) = oneshot::channel();
// Send the request
self.tx.send((set, tx)).await?;
// Await the response
match rx.await {
Ok(res) => res.map(|_| ()),
Err(err) => Err(err.into()),
}
}
}

View File

@ -0,0 +1,473 @@
//! Minimal Redis client implementation
//!
//! Provides an async connect and methods for issuing the supported commands.
use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe};
use crate::{Connection, Frame};
use async_stream::try_stream;
use bytes::Bytes;
use std::io::{Error, ErrorKind};
use std::time::Duration;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio_stream::Stream;
use tracing::{debug, instrument};
/// Established connection with a Redis server.
///
/// Backed by a single `TcpStream`, `Client` provides basic network client
/// functionality (no pooling, retrying, ...). Connections are established using
/// the [`connect`](fn@connect) function.
///
/// Requests are issued using the various methods of `Client`.
pub struct Client {
/// The TCP connection decorated with the redis protocol encoder / decoder
/// implemented using a buffered `TcpStream`.
///
/// When `Listener` receives an inbound connection, the `TcpStream` is
/// passed to `Connection::new`, which initializes the associated buffers.
/// `Connection` allows the handler to operate at the "frame" level and keep
/// the byte level protocol parsing details encapsulated in `Connection`.
connection: Connection,
}
/// A client that has entered pub/sub mode.
///
/// Once clients subscribe to a channel, they may only perform pub/sub related
/// commands. The `Client` type is transitioned to a `Subscriber` type in order
/// to prevent non-pub/sub methods from being called.
pub struct Subscriber {
/// The subscribed client.
client: Client,
/// The set of channels to which the `Subscriber` is currently subscribed.
subscribed_channels: Vec<String>,
}
/// A message received on a subscribed channel.
#[derive(Debug, Clone)]
pub struct Message {
pub channel: String,
pub content: Bytes,
}
/// Establish a connection with the Redis server located at `addr`.
///
/// `addr` may be any type that can be asynchronously converted to a
/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs`
/// trait is the Tokio version and not the `std` version.
///
/// # Examples
///
/// ```no_run
/// use mini_redis::client;
///
/// #[tokio::main]
/// async fn main() {
/// let client = match client::connect("localhost:6379").await {
/// Ok(client) => client,
/// Err(_) => panic!("failed to establish connection"),
/// };
/// # drop(client);
/// }
/// ```
///
pub async fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<Client> {
// The `addr` argument is passed directly to `TcpStream::connect`. This
// performs any asynchronous DNS lookup and attempts to establish the TCP
// connection. An error at either step returns an error, which is then
// bubbled up to the caller of `mini_redis` connect.
let socket = TcpStream::connect(addr).await?;
// Initialize the connection state. This allocates read/write buffers to
// perform redis protocol frame parsing.
let connection = Connection::new(socket);
Ok(Client { connection })
}
impl Client {
/// Get the value of key.
///
/// If the key does not exist the special value `None` is returned.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::client;
///
/// #[tokio::main]
/// async fn main() {
/// let mut client = client::connect("localhost:6379").await.unwrap();
///
/// let val = client.get("foo").await.unwrap();
/// println!("Got = {:?}", val);
/// }
/// ```
#[instrument(skip(self))]
pub async fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> {
// Create a `Get` command for the `key` and convert it to a frame.
let frame = Get::new(key).into_frame();
debug!(request = ?frame);
// Write the frame to the socket. This writes the full frame to the
// socket, waiting if necessary.
self.connection.write_frame(&frame).await?;
// Wait for the response from the server
//
// Both `Simple` and `Bulk` frames are accepted. `Null` represents the
// key not being present and `None` is returned.
match self.read_response().await? {
Frame::Simple(value) => Ok(Some(value.into())),
Frame::Bulk(value) => Ok(Some(value)),
Frame::Null => Ok(None),
frame => Err(frame.to_error()),
}
}
/// Set `key` to hold the given `value`.
///
/// The `value` is associated with `key` until it is overwritten by the next
/// call to `set` or it is removed.
///
/// If key already holds a value, it is overwritten. Any previous time to
/// live associated with the key is discarded on successful SET operation.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::client;
///
/// #[tokio::main]
/// async fn main() {
/// let mut client = client::connect("localhost:6379").await.unwrap();
///
/// client.set("foo", "bar".into()).await.unwrap();
///
/// // Getting the value immediately works
/// let val = client.get("foo").await.unwrap().unwrap();
/// assert_eq!(val, "bar");
/// }
/// ```
#[instrument(skip(self))]
pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
// Create a `Set` command and pass it to `set_cmd`. A separate method is
// used to set a value with an expiration. The common parts of both
// functions are implemented by `set_cmd`.
self.set_cmd(Set::new(key, value, None)).await
}
/// Set `key` to hold the given `value`. The value expires after `expiration`
///
/// The `value` is associated with `key` until one of the following:
/// - it expires.
/// - it is overwritten by the next call to `set`.
/// - it is removed.
///
/// If key already holds a value, it is overwritten. Any previous time to
/// live associated with the key is discarded on a successful SET operation.
///
/// # Examples
///
/// Demonstrates basic usage. This example is not **guaranteed** to always
/// work as it relies on time based logic and assumes the client and server
/// stay relatively synchronized in time. The real world tends to not be so
/// favorable.
///
/// ```no_run
/// use mini_redis::client;
/// use tokio::time;
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
/// let ttl = Duration::from_millis(500);
/// let mut client = client::connect("localhost:6379").await.unwrap();
///
/// client.set_expires("foo", "bar".into(), ttl).await.unwrap();
///
/// // Getting the value immediately works
/// let val = client.get("foo").await.unwrap().unwrap();
/// assert_eq!(val, "bar");
///
/// // Wait for the TTL to expire
/// time::sleep(ttl).await;
///
/// let val = client.get("foo").await.unwrap();
/// assert!(val.is_some());
/// }
/// ```
#[instrument(skip(self))]
pub async fn set_expires(
&mut self,
key: &str,
value: Bytes,
expiration: Duration,
) -> crate::Result<()> {
// Create a `Set` command and pass it to `set_cmd`. A separate method is
// used to set a value with an expiration. The common parts of both
// functions are implemented by `set_cmd`.
self.set_cmd(Set::new(key, value, Some(expiration))).await
}
/// The core `SET` logic, used by both `set` and `set_expires.
async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> {
// Convert the `Set` command into a frame
let frame = cmd.into_frame();
debug!(request = ?frame);
// Write the frame to the socket. This writes the full frame to the
// socket, waiting if necessary.
self.connection.write_frame(&frame).await?;
// Wait for the response from the server. On success, the server
// responds simply with `OK`. Any other response indicates an error.
match self.read_response().await? {
Frame::Simple(response) if response == "OK" => Ok(()),
frame => Err(frame.to_error()),
}
}
/// Posts `message` to the given `channel`.
///
/// Returns the number of subscribers currently listening on the channel.
/// There is no guarantee that these subscribers receive the message as they
/// may disconnect at any time.
///
/// # Examples
///
/// Demonstrates basic usage.
///
/// ```no_run
/// use mini_redis::client;
///
/// #[tokio::main]
/// async fn main() {
/// let mut client = client::connect("localhost:6379").await.unwrap();
///
/// let val = client.publish("foo", "bar".into()).await.unwrap();
/// println!("Got = {:?}", val);
/// }
/// ```
#[instrument(skip(self))]
pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> {
// Convert the `Publish` command into a frame
let frame = Publish::new(channel, message).into_frame();
debug!(request = ?frame);
// Write the frame to the socket
self.connection.write_frame(&frame).await?;
// Read the response
match self.read_response().await? {
Frame::Integer(response) => Ok(response),
frame => Err(frame.to_error()),
}
}
/// Subscribes the client to the specified channels.
///
/// Once a client issues a subscribe command, it may no longer issue any
/// non-pub/sub commands. The function consumes `self` and returns a `Subscriber`.
///
/// The `Subscriber` value is used to receive messages as well as manage the
/// list of channels the client is subscribed to.
#[instrument(skip(self))]
pub async fn subscribe(mut self, channels: Vec<String>) -> crate::Result<Subscriber> {
// Issue the subscribe command to the server and wait for confirmation.
// The client will then have been transitioned into the "subscriber"
// state and may only issue pub/sub commands from that point on.
self.subscribe_cmd(&channels).await?;
// Return the `Subscriber` type
Ok(Subscriber {
client: self,
subscribed_channels: channels,
})
}
/// The core `SUBSCRIBE` logic, used by misc subscribe fns
async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> {
// Convert the `Subscribe` command into a frame
let frame = Subscribe::new(&channels).into_frame();
debug!(request = ?frame);
// Write the frame to the socket
self.connection.write_frame(&frame).await?;
// For each channel being subscribed to, the server responds with a
// message confirming subscription to that channel.
for channel in channels {
// Read the response
let response = self.read_response().await?;
// Verify it is confirmation of subscription.
match response {
Frame::Array(ref frame) => match frame.as_slice() {
// The server responds with an array frame in the form of:
//
// ```
// [ "subscribe", channel, num-subscribed ]
// ```
//
// where channel is the name of the channel and
// num-subscribed is the number of channels that the client
// is currently subscribed to.
[subscribe, schannel, ..]
if *subscribe == "subscribe" && *schannel == channel => {}
_ => return Err(response.to_error()),
},
frame => return Err(frame.to_error()),
};
}
Ok(())
}
/// Reads a response frame from the socket.
///
/// If an `Error` frame is received, it is converted to `Err`.
async fn read_response(&mut self) -> crate::Result<Frame> {
let response = self.connection.read_frame().await?;
debug!(?response);
match response {
// Error frames are converted to `Err`
Some(Frame::Error(msg)) => Err(msg.into()),
Some(frame) => Ok(frame),
None => {
// Receiving `None` here indicates the server has closed the
// connection without sending a frame. This is unexpected and is
// represented as a "connection reset by peer" error.
let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server");
Err(err.into())
}
}
}
}
impl Subscriber {
/// Returns the set of channels currently subscribed to.
pub fn get_subscribed(&self) -> &[String] {
&self.subscribed_channels
}
/// Receive the next message published on a subscribed channel, waiting if
/// necessary.
///
/// `None` indicates the subscription has been terminated.
pub async fn next_message(&mut self) -> crate::Result<Option<Message>> {
match self.client.connection.read_frame().await? {
Some(mframe) => {
debug!(?mframe);
match mframe {
Frame::Array(ref frame) => match frame.as_slice() {
[message, channel, content] if *message == "message" => Ok(Some(Message {
channel: channel.to_string(),
content: Bytes::from(content.to_string()),
})),
_ => Err(mframe.to_error()),
},
frame => Err(frame.to_error()),
}
}
None => Ok(None),
}
}
/// Convert the subscriber into a `Stream` yielding new messages published
/// on subscribed channels.
///
/// `Subscriber` does not implement stream itself as doing so with safe code
/// is non trivial. The usage of async/await would require a manual Stream
/// implementation to use `unsafe` code. Instead, a conversion function is
/// provided and the returned stream is implemented with the help of the
/// `async-stream` crate.
pub fn into_stream(mut self) -> impl Stream<Item = crate::Result<Message>> {
// Uses the `try_stream` macro from the `async-stream` crate. Generators
// are not stable in Rust. The crate uses a macro to simulate generators
// on top of async/await. There are limitations, so read the
// documentation there.
try_stream! {
while let Some(message) = self.next_message().await? {
yield message;
}
}
}
/// Subscribe to a list of new channels
#[instrument(skip(self))]
pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> {
// Issue the subscribe command
self.client.subscribe_cmd(channels).await?;
// Update the set of subscribed channels.
self.subscribed_channels
.extend(channels.iter().map(Clone::clone));
Ok(())
}
/// Unsubscribe to a list of new channels
#[instrument(skip(self))]
pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> {
let frame = Unsubscribe::new(&channels).into_frame();
debug!(request = ?frame);
// Write the frame to the socket
self.client.connection.write_frame(&frame).await?;
// if the input channel list is empty, server acknowledges as unsubscribing
// from all subscribed channels, so we assert that the unsubscribe list received
// matches the client subscribed one
let num = if channels.is_empty() {
self.subscribed_channels.len()
} else {
channels.len()
};
// Read the response
for _ in 0..num {
let response = self.client.read_response().await?;
match response {
Frame::Array(ref frame) => match frame.as_slice() {
[unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => {
let len = self.subscribed_channels.len();
if len == 0 {
// There must be at least one channel
return Err(response.to_error());
}
// unsubscribed channel should exist in the subscribed list at this point
self.subscribed_channels.retain(|c| *channel != &c[..]);
// Only a single channel should be removed from the
// list of subscribed channels.
if self.subscribed_channels.len() != len - 1 {
return Err(response.to_error());
}
}
_ => return Err(response.to_error()),
},
frame => return Err(frame.to_error()),
};
}
Ok(())
}
}

View File

@ -0,0 +1,93 @@
use crate::{Connection, Db, Frame, Parse};
use bytes::Bytes;
use tracing::{debug, instrument};
/// Get the value of key.
///
/// If the key does not exist the special value nil is returned. An error is
/// returned if the value stored at key is not a string, because GET only
/// handles string values.
#[derive(Debug)]
pub struct Get {
/// Name of the key to get
key: String,
}
impl Get {
/// Create a new `Get` command which fetches `key`.
pub fn new(key: impl ToString) -> Get {
Get {
key: key.to_string(),
}
}
/// Get the key
pub fn key(&self) -> &str {
&self.key
}
/// Parse a `Get` instance from a received frame.
///
/// The `Parse` argument provides a cursor-like API to read fields from the
/// `Frame`. At this point, the entire frame has already been received from
/// the socket.
///
/// The `GET` string has already been consumed.
///
/// # Returns
///
/// Returns the `Get` value on success. If the frame is malformed, `Err` is
/// returned.
///
/// # Format
///
/// Expects an array frame containing two entries.
///
/// ```text
/// GET key
/// ```
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Get> {
// The `GET` string has already been consumed. The next value is the
// name of the key to get. If the next value is not a string or the
// input is fully consumed, then an error is returned.
let key = parse.next_string()?;
Ok(Get { key })
}
/// Apply the `Get` command to the specified `Db` instance.
///
/// The response is written to `dst`. This is called by the server in order
/// to execute a received command.
#[instrument(skip(self, db, dst))]
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
// Get the value from the shared database state
let response = if let Some(value) = db.get(&self.key) {
// If a value is present, it is written to the client in "bulk"
// format.
Frame::Bulk(value)
} else {
// If there is no value, `Null` is written.
Frame::Null
};
debug!(?response);
// Write the response back to the client
dst.write_frame(&response).await?;
Ok(())
}
/// Converts the command into an equivalent `Frame`.
///
/// This is called by the client when encoding a `Get` command to send to
/// the server.
pub(crate) fn into_frame(self) -> Frame {
let mut frame = Frame::array();
frame.push_bulk(Bytes::from("get".as_bytes()));
frame.push_bulk(Bytes::from(self.key.into_bytes()));
frame
}
}

View File

@ -0,0 +1,116 @@
mod get;
pub use get::Get;
mod publish;
pub use publish::Publish;
mod set;
pub use set::Set;
mod subscribe;
pub use subscribe::{Subscribe, Unsubscribe};
mod unknown;
pub use unknown::Unknown;
use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown};
/// Enumeration of supported Redis commands.
///
/// Methods called on `Command` are delegated to the command implementation.
#[derive(Debug)]
pub enum Command {
Get(Get),
Publish(Publish),
Set(Set),
Subscribe(Subscribe),
Unsubscribe(Unsubscribe),
Unknown(Unknown),
}
impl Command {
/// Parse a command from a received frame.
///
/// The `Frame` must represent a Redis command supported by `mini-redis` and
/// be the array variant.
///
/// # Returns
///
/// On success, the command value is returned, otherwise, `Err` is returned.
pub fn from_frame(frame: Frame) -> crate::Result<Command> {
// The frame value is decorated with `Parse`. `Parse` provides a
// "cursor" like API which makes parsing the command easier.
//
// The frame value must be an array variant. Any other frame variants
// result in an error being returned.
let mut parse = Parse::new(frame)?;
// All redis commands begin with the command name as a string. The name
// is read and converted to lower cases in order to do case sensitive
// matching.
let command_name = parse.next_string()?.to_lowercase();
// Match the command name, delegating the rest of the parsing to the
// specific command.
let command = match &command_name[..] {
"get" => Command::Get(Get::parse_frames(&mut parse)?),
"publish" => Command::Publish(Publish::parse_frames(&mut parse)?),
"set" => Command::Set(Set::parse_frames(&mut parse)?),
"subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?),
"unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?),
_ => {
// The command is not recognized and an Unknown command is
// returned.
//
// `return` is called here to skip the `finish()` call below. As
// the command is not recognized, there is most likely
// unconsumed fields remaining in the `Parse` instance.
return Ok(Command::Unknown(Unknown::new(command_name)));
}
};
// Check if there is any remaining unconsumed fields in the `Parse`
// value. If fields remain, this indicates an unexpected frame format
// and an error is returned.
parse.finish()?;
// The command has been successfully parsed
Ok(command)
}
/// Apply the command to the specified `Db` instance.
///
/// The response is written to `dst`. This is called by the server in order
/// to execute a received command.
pub(crate) async fn apply(
self,
db: &Db,
dst: &mut Connection,
shutdown: &mut Shutdown,
) -> crate::Result<()> {
use Command::*;
match self {
Get(cmd) => cmd.apply(db, dst).await,
Publish(cmd) => cmd.apply(db, dst).await,
Set(cmd) => cmd.apply(db, dst).await,
Subscribe(cmd) => cmd.apply(db, dst, shutdown).await,
Unknown(cmd) => cmd.apply(dst).await,
// `Unsubscribe` cannot be applied. It may only be received from the
// context of a `Subscribe` command.
Unsubscribe(_) => Err("`Unsubscribe` is unsupported in this context".into()),
}
}
/// Returns the command name
pub(crate) fn get_name(&self) -> &str {
match self {
Command::Get(_) => "get",
Command::Publish(_) => "pub",
Command::Set(_) => "set",
Command::Subscribe(_) => "subscribe",
Command::Unsubscribe(_) => "unsubscribe",
Command::Unknown(cmd) => cmd.get_name(),
}
}
}

View File

@ -0,0 +1,101 @@
use crate::{Connection, Db, Frame, Parse};
use bytes::Bytes;
/// Posts a message to the given channel.
///
/// Send a message into a channel without any knowledge of individual consumers.
/// Consumers may subscribe to channels in order to receive the messages.
///
/// Channel names have no relation to the key-value namespace. Publishing on a
/// channel named "foo" has no relation to setting the "foo" key.
#[derive(Debug)]
pub struct Publish {
/// Name of the channel on which the message should be published.
channel: String,
/// The message to publish.
message: Bytes,
}
impl Publish {
/// Create a new `Publish` command which sends `message` on `channel`.
pub(crate) fn new(channel: impl ToString, message: Bytes) -> Publish {
Publish {
channel: channel.to_string(),
message,
}
}
/// Parse a `Publish` instance from a received frame.
///
/// The `Parse` argument provides a cursor-like API to read fields from the
/// `Frame`. At this point, the entire frame has already been received from
/// the socket.
///
/// The `PUBLISH` string has already been consumed.
///
/// # Returns
///
/// On success, the `Publish` value is returned. If the frame is malformed,
/// `Err` is returned.
///
/// # Format
///
/// Expects an array frame containing three entries.
///
/// ```text
/// PUBLISH channel message
/// ```
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Publish> {
// The `PUBLISH` string has already been consumed. Extract the `channel`
// and `message` values from the frame.
//
// The `channel` must be a valid string.
let channel = parse.next_string()?;
// The `message` is arbitrary bytes.
let message = parse.next_bytes()?;
Ok(Publish { channel, message })
}
/// Apply the `Publish` command to the specified `Db` instance.
///
/// The response is written to `dst`. This is called by the server in order
/// to execute a received command.
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
// The shared state contains the `tokio::sync::broadcast::Sender` for
// all active channels. Calling `db.publish` dispatches the message into
// the appropriate channel.
//
// The number of subscribers currently listening on the channel is
// returned. This does not mean that `num_subscriber` channels will
// receive the message. Subscribers may drop before receiving the
// message. Given this, `num_subscribers` should only be used as a
// "hint".
let num_subscribers = db.publish(&self.channel, self.message);
// The number of subscribers is returned as the response to the publish
// request.
let response = Frame::Integer(num_subscribers as u64);
// Write the frame to the client.
dst.write_frame(&response).await?;
Ok(())
}
/// Converts the command into an equivalent `Frame`.
///
/// This is called by the client when encoding a `Publish` command to send
/// to the server.
pub(crate) fn into_frame(self) -> Frame {
let mut frame = Frame::array();
frame.push_bulk(Bytes::from("publish".as_bytes()));
frame.push_bulk(Bytes::from(self.channel.into_bytes()));
frame.push_bulk(self.message);
frame
}
}

View File

@ -0,0 +1,161 @@
use crate::cmd::{Parse, ParseError};
use crate::{Connection, Db, Frame};
use bytes::Bytes;
use std::time::Duration;
use tracing::{debug, instrument};
/// Set `key` to hold the string `value`.
///
/// If `key` already holds a value, it is overwritten, regardless of its type.
/// Any previous time to live associated with the key is discarded on successful
/// SET operation.
///
/// # Options
///
/// Currently, the following options are supported:
///
/// * EX `seconds` -- Set the specified expire time, in seconds.
/// * PX `milliseconds` -- Set the specified expire time, in milliseconds.
#[derive(Debug)]
pub struct Set {
/// the lookup key
key: String,
/// the value to be stored
value: Bytes,
/// When to expire the key
expire: Option<Duration>,
}
impl Set {
/// Create a new `Set` command which sets `key` to `value`.
///
/// If `expire` is `Some`, the value should expire after the specified
/// duration.
pub fn new(key: impl ToString, value: Bytes, expire: Option<Duration>) -> Set {
Set {
key: key.to_string(),
value,
expire,
}
}
/// Get the key
pub fn key(&self) -> &str {
&self.key
}
/// Get the value
pub fn value(&self) -> &Bytes {
&self.value
}
/// Get the expire
pub fn expire(&self) -> Option<Duration> {
self.expire
}
/// Parse a `Set` instance from a received frame.
///
/// The `Parse` argument provides a cursor-like API to read fields from the
/// `Frame`. At this point, the entire frame has already been received from
/// the socket.
///
/// The `SET` string has already been consumed.
///
/// # Returns
///
/// Returns the `Set` value on success. If the frame is malformed, `Err` is
/// returned.
///
/// # Format
///
/// Expects an array frame containing at least 3 entries.
///
/// ```text
/// SET key value [EX seconds|PX milliseconds]
/// ```
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Set> {
use ParseError::EndOfStream;
// Read the key to set. This is a required field
let key = parse.next_string()?;
// Read the value to set. This is a required field.
let value = parse.next_bytes()?;
// The expiration is optional. If nothing else follows, then it is
// `None`.
let mut expire = None;
// Attempt to parse another string.
match parse.next_string() {
Ok(s) if s.to_uppercase() == "EX" => {
// An expiration is specified in seconds. The next value is an
// integer.
let secs = parse.next_int()?;
expire = Some(Duration::from_secs(secs));
}
Ok(s) if s.to_uppercase() == "PX" => {
// An expiration is specified in milliseconds. The next value is
// an integer.
let ms = parse.next_int()?;
expire = Some(Duration::from_millis(ms));
}
// Currently, mini-redis does not support any of the other SET
// options. An error here results in the connection being
// terminated. Other connections will continue to operate normally.
Ok(_) => return Err("currently `SET` only supports the expiration option".into()),
// The `EndOfStream` error indicates there is no further data to
// parse. In this case, it is a normal run time situation and
// indicates there are no specified `SET` options.
Err(EndOfStream) => {}
// All other errors are bubbled up, resulting in the connection
// being terminated.
Err(err) => return Err(err.into()),
}
Ok(Set { key, value, expire })
}
/// Apply the `Set` command to the specified `Db` instance.
///
/// The response is written to `dst`. This is called by the server in order
/// to execute a received command.
#[instrument(skip(self, db, dst))]
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
// Set the value in the shared database state.
db.set(self.key, self.value, self.expire);
// Create a success response and write it to `dst`.
let response = Frame::Simple("OK".to_string());
debug!(?response);
dst.write_frame(&response).await?;
Ok(())
}
/// Converts the command into an equivalent `Frame`.
///
/// This is called by the client when encoding a `Set` command to send to
/// the server.
pub(crate) fn into_frame(self) -> Frame {
let mut frame = Frame::array();
frame.push_bulk(Bytes::from("set".as_bytes()));
frame.push_bulk(Bytes::from(self.key.into_bytes()));
frame.push_bulk(self.value);
if let Some(ms) = self.expire {
// Expirations in Redis procotol can be specified in two ways
// 1. SET key value EX seconds
// 2. SET key value PX milliseconds
// We the second option because it allows greater precision and
// src/bin/cli.rs parses the expiration argument as milliseconds
// in duration_from_ms_str()
frame.push_bulk(Bytes::from("px".as_bytes()));
frame.push_int(ms.as_millis() as u64);
}
frame
}
}

View File

@ -0,0 +1,351 @@
use crate::cmd::{Parse, ParseError, Unknown};
use crate::{Command, Connection, Db, Frame, Shutdown};
use bytes::Bytes;
use std::pin::Pin;
use tokio::select;
use tokio::sync::broadcast;
use tokio_stream::{Stream, StreamExt, StreamMap};
/// Subscribes the client to one or more channels.
///
/// Once the client enters the subscribed state, it is not supposed to issue any
/// other commands, except for additional SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE,
/// PUNSUBSCRIBE, PING and QUIT commands.
#[derive(Debug)]
pub struct Subscribe {
channels: Vec<String>,
}
/// Unsubscribes the client from one or more channels.
///
/// When no channels are specified, the client is unsubscribed from all the
/// previously subscribed channels.
#[derive(Clone, Debug)]
pub struct Unsubscribe {
channels: Vec<String>,
}
/// Stream of messages. The stream receives messages from the
/// `broadcast::Receiver`. We use `stream!` to create a `Stream` that consumes
/// messages. Because `stream!` values cannot be named, we box the stream using
/// a trait object.
type Messages = Pin<Box<dyn Stream<Item = Bytes> + Send>>;
impl Subscribe {
/// Creates a new `Subscribe` command to listen on the specified channels.
pub(crate) fn new(channels: &[String]) -> Subscribe {
Subscribe {
channels: channels.to_vec(),
}
}
/// Parse a `Subscribe` instance from a received frame.
///
/// The `Parse` argument provides a cursor-like API to read fields from the
/// `Frame`. At this point, the entire frame has already been received from
/// the socket.
///
/// The `SUBSCRIBE` string has already been consumed.
///
/// # Returns
///
/// On success, the `Subscribe` value is returned. If the frame is
/// malformed, `Err` is returned.
///
/// # Format
///
/// Expects an array frame containing two or more entries.
///
/// ```text
/// SUBSCRIBE channel [channel ...]
/// ```
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
use ParseError::EndOfStream;
// The `SUBSCRIBE` string has already been consumed. At this point,
// there is one or more strings remaining in `parse`. These represent
// the channels to subscribe to.
//
// Extract the first string. If there is none, the the frame is
// malformed and the error is bubbled up.
let mut channels = vec![parse.next_string()?];
// Now, the remainder of the frame is consumed. Each value must be a
// string or the frame is malformed. Once all values in the frame have
// been consumed, the command is fully parsed.
loop {
match parse.next_string() {
// A string has been consumed from the `parse`, push it into the
// list of channels to subscribe to.
Ok(s) => channels.push(s),
// The `EndOfStream` error indicates there is no further data to
// parse.
Err(EndOfStream) => break,
// All other errors are bubbled up, resulting in the connection
// being terminated.
Err(err) => return Err(err.into()),
}
}
Ok(Subscribe { channels })
}
/// Apply the `Subscribe` command to the specified `Db` instance.
///
/// This function is the entry point and includes the initial list of
/// channels to subscribe to. Additional `subscribe` and `unsubscribe`
/// commands may be received from the client and the list of subscriptions
/// are updated accordingly.
///
/// [here]: https://redis.io/topics/pubsub
pub(crate) async fn apply(
mut self,
db: &Db,
dst: &mut Connection,
shutdown: &mut Shutdown,
) -> crate::Result<()> {
// Each individual channel subscription is handled using a
// `sync::broadcast` channel. Messages are then fanned out to all
// clients currently subscribed to the channels.
//
// An individual client may subscribe to multiple channels and may
// dynamically add and remove channels from its subscription set. To
// handle this, a `StreamMap` is used to track active subscriptions. The
// `StreamMap` merges messages from individual broadcast channels as
// they are received.
let mut subscriptions = StreamMap::new();
loop {
// `self.channels` is used to track additional channels to subscribe
// to. When new `SUBSCRIBE` commands are received during the
// execution of `apply`, the new channels are pushed onto this vec.
for channel_name in self.channels.drain(..) {
subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?;
}
// Wait for one of the following to happen:
//
// - Receive a message from one of the subscribed channels.
// - Receive a subscribe or unsubscribe command from the client.
// - A server shutdown signal.
select! {
// Receive messages from subscribed channels
Some((channel_name, msg)) = subscriptions.next() => {
dst.write_frame(&make_message_frame(channel_name, msg)).await?;
},
res = dst.read_frame() => {
let frame = match res? {
Some(frame) => frame,
// This happens if the remote client has disconnected.
None => return Ok(())
};
handle_command(
frame,
&mut self.channels,
&mut subscriptions,
dst,
).await?;
},
_ = shutdown.recv() => {
return Ok(());
}
};
}
}
/// Converts the command into an equivalent `Frame`.
///
/// This is called by the client when encoding a `Subscribe` command to send
/// to the server.
pub(crate) fn into_frame(self) -> Frame {
let mut frame = Frame::array();
frame.push_bulk(Bytes::from("subscribe".as_bytes()));
for channel in self.channels {
frame.push_bulk(Bytes::from(channel.into_bytes()));
}
frame
}
}
async fn subscribe_to_channel(
channel_name: String,
subscriptions: &mut StreamMap<String, Messages>,
db: &Db,
dst: &mut Connection,
) -> crate::Result<()> {
let mut rx = db.subscribe(channel_name.clone());
// Subscribe to the channel.
let rx = Box::pin(async_stream::stream! {
loop {
match rx.recv().await {
Ok(msg) => yield msg,
// If we lagged in consuming messages, just resume.
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(_) => break,
}
}
});
// Track subscription in this client's subscription set.
subscriptions.insert(channel_name.clone(), rx);
// Respond with the successful subscription
let response = make_subscribe_frame(channel_name, subscriptions.len());
dst.write_frame(&response).await?;
Ok(())
}
/// Handle a command received while inside `Subscribe::apply`. Only subscribe
/// and unsubscribe commands are permitted in this context.
///
/// Any new subscriptions are appended to `subscribe_to` instead of modifying
/// `subscriptions`.
async fn handle_command(
frame: Frame,
subscribe_to: &mut Vec<String>,
subscriptions: &mut StreamMap<String, Messages>,
dst: &mut Connection,
) -> crate::Result<()> {
// A command has been received from the client.
//
// Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted
// in this context.
match Command::from_frame(frame)? {
Command::Subscribe(subscribe) => {
// The `apply` method will subscribe to the channels we add to this
// vector.
subscribe_to.extend(subscribe.channels.into_iter());
}
Command::Unsubscribe(mut unsubscribe) => {
// If no channels are specified, this requests unsubscribing from
// **all** channels. To implement this, the `unsubscribe.channels`
// vec is populated with the list of channels currently subscribed
// to.
if unsubscribe.channels.is_empty() {
unsubscribe.channels = subscriptions
.keys()
.map(|channel_name| channel_name.to_string())
.collect();
}
for channel_name in unsubscribe.channels {
subscriptions.remove(&channel_name);
let response = make_unsubscribe_frame(channel_name, subscriptions.len());
dst.write_frame(&response).await?;
}
}
command => {
let cmd = Unknown::new(command.get_name());
cmd.apply(dst).await?;
}
}
Ok(())
}
/// Creates the response to a subcribe request.
///
/// All of these functions take the `channel_name` as a `String` instead of
/// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and
/// taking a `&str` would require copying the data. This allows the caller to
/// decide whether to clone the channel name or not.
fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame {
let mut response = Frame::array();
response.push_bulk(Bytes::from_static(b"subscribe"));
response.push_bulk(Bytes::from(channel_name));
response.push_int(num_subs as u64);
response
}
/// Creates the response to an unsubcribe request.
fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame {
let mut response = Frame::array();
response.push_bulk(Bytes::from_static(b"unsubscribe"));
response.push_bulk(Bytes::from(channel_name));
response.push_int(num_subs as u64);
response
}
/// Creates a message informing the client about a new message on a channel that
/// the client subscribes to.
fn make_message_frame(channel_name: String, msg: Bytes) -> Frame {
let mut response = Frame::array();
response.push_bulk(Bytes::from_static(b"message"));
response.push_bulk(Bytes::from(channel_name));
response.push_bulk(msg);
response
}
impl Unsubscribe {
/// Create a new `Unsubscribe` command with the given `channels`.
pub(crate) fn new(channels: &[String]) -> Unsubscribe {
Unsubscribe {
channels: channels.to_vec(),
}
}
/// Parse a `Unsubscribe` instance from a received frame.
///
/// The `Parse` argument provides a cursor-like API to read fields from the
/// `Frame`. At this point, the entire frame has already been received from
/// the socket.
///
/// The `UNSUBSCRIBE` string has already been consumed.
///
/// # Returns
///
/// On success, the `Unsubscribe` value is returned. If the frame is
/// malformed, `Err` is returned.
///
/// # Format
///
/// Expects an array frame containing at least one entry.
///
/// ```text
/// UNSUBSCRIBE [channel [channel ...]]
/// ```
pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Unsubscribe, ParseError> {
use ParseError::EndOfStream;
// There may be no channels listed, so start with an empty vec.
let mut channels = vec![];
// Each entry in the frame must be a string or the frame is malformed.
// Once all values in the frame have been consumed, the command is fully
// parsed.
loop {
match parse.next_string() {
// A string has been consumed from the `parse`, push it into the
// list of channels to unsubscribe from.
Ok(s) => channels.push(s),
// The `EndOfStream` error indicates there is no further data to
// parse.
Err(EndOfStream) => break,
// All other errors are bubbled up, resulting in the connection
// being terminated.
Err(err) => return Err(err),
}
}
Ok(Unsubscribe { channels })
}
/// Converts the command into an equivalent `Frame`.
///
/// This is called by the client when encoding an `Unsubscribe` command to
/// send to the server.
pub(crate) fn into_frame(self) -> Frame {
let mut frame = Frame::array();
frame.push_bulk(Bytes::from("unsubscribe".as_bytes()));
for channel in self.channels {
frame.push_bulk(Bytes::from(channel.into_bytes()));
}
frame
}
}

View File

@ -0,0 +1,37 @@
use crate::{Connection, Frame};
use tracing::{debug, instrument};
/// Represents an "unknown" command. This is not a real `Redis` command.
#[derive(Debug)]
pub struct Unknown {
command_name: String,
}
impl Unknown {
/// Create a new `Unknown` command which responds to unknown commands
/// issued by clients
pub(crate) fn new(key: impl ToString) -> Unknown {
Unknown {
command_name: key.to_string(),
}
}
/// Returns the command name
pub(crate) fn get_name(&self) -> &str {
&self.command_name
}
/// Responds to the client, indicating the command is not recognized.
///
/// This usually means the command is not yet implemented by `mini-redis`.
#[instrument(skip(self, dst))]
pub(crate) async fn apply(self, dst: &mut Connection) -> crate::Result<()> {
let response = Frame::Error(format!("ERR unknown command '{}'", self.command_name));
debug!(?response);
dst.write_frame(&response).await?;
Ok(())
}
}

View File

@ -0,0 +1,237 @@
use crate::frame::{self, Frame};
use bytes::{Buf, BytesMut};
use std::io::{self, Cursor};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;
/// Send and receive `Frame` values from a remote peer.
///
/// When implementing networking protocols, a message on that protocol is
/// often composed of several smaller messages known as frames. The purpose of
/// `Connection` is to read and write frames on the underlying `TcpStream`.
///
/// To read frames, the `Connection` uses an internal buffer, which is filled
/// up until there are enough bytes to create a full frame. Once this happens,
/// the `Connection` creates the frame and returns it to the caller.
///
/// When sending frames, the frame is first encoded into the write buffer.
/// The contents of the write buffer are then written to the socket.
#[derive(Debug)]
pub struct Connection {
// The `TcpStream`. It is decorated with a `BufWriter`, which provides write
// level buffering. The `BufWriter` implementation provided by Tokio is
// sufficient for our needs.
stream: BufWriter<TcpStream>,
// The buffer for reading frames.
buffer: BytesMut,
}
impl Connection {
/// Create a new `Connection`, backed by `socket`. Read and write buffers
/// are initialized.
pub fn new(socket: TcpStream) -> Connection {
Connection {
stream: BufWriter::new(socket),
// Default to a 4KB read buffer. For the use case of mini redis,
// this is fine. However, real applications will want to tune this
// value to their specific use case. There is a high likelihood that
// a larger read buffer will work better.
buffer: BytesMut::with_capacity(4 * 1024),
}
}
/// Read a single `Frame` value from the underlying stream.
///
/// The function waits until it has retrieved enough data to parse a frame.
/// Any data remaining in the read buffer after the frame has been parsed is
/// kept there for the next call to `read_frame`.
///
/// # Returns
///
/// On success, the received frame is returned. If the `TcpStream`
/// is closed in a way that doesn't break a frame in half, it returns
/// `None`. Otherwise, an error is returned.
pub async fn read_frame(&mut self) -> crate::Result<Option<Frame>> {
loop {
// Attempt to parse a frame from the buffered data. If enough data
// has been buffered, the frame is returned.
if let Some(frame) = self.parse_frame()? {
return Ok(Some(frame));
}
// There is not enough buffered data to read a frame. Attempt to
// read more data from the socket.
//
// On success, the number of bytes is returned. `0` indicates "end
// of stream".
if 0 == self.stream.read_buf(&mut self.buffer).await? {
// The remote closed the connection. For this to be a clean
// shutdown, there should be no data in the read buffer. If
// there is, this means that the peer closed the socket while
// sending a frame.
if self.buffer.is_empty() {
return Ok(None);
} else {
let s = "connection reset by peer".into();
return Err(s);
}
}
}
}
/// Tries to parse a frame from the buffer. If the buffer contains enough
/// data, the frame is returned and the data removed from the buffer. If not
/// enough data has been buffered yet, `Ok(None)` is returned. If the
/// buffered data does not represent a valid frame, `Err` is returned.
fn parse_frame(&mut self) -> crate::Result<Option<Frame>> {
use frame::Error::Incomplete;
// Cursor is used to track the "current" location in the
// buffer. Cursor also implements `Buf` from the `bytes` crate
// which provides a number of helpful utilities for working
// with bytes.
let mut buf = Cursor::new(&self.buffer[..]);
// The first step is to check if enough data has been buffered to parse
// a single frame. This step is usually much faster than doing a full
// parse of the frame, and allows us to skip allocating data structures
// to hold the frame data unless we know the full frame has been
// received.
match Frame::check(&mut buf) {
Ok(_) => {
// The `check` function will have advanced the cursor until the
// end of the frame. Since the cursor had position set to zero
// before `Frame::check` was called, we obtain the length of the
// frame by checking the cursor position.
let len = buf.position() as usize;
// Reset the position to zero before passing the cursor to
// `Frame::parse`.
buf.set_position(0);
// Parse the frame from the buffer. This allocates the necessary
// structures to represent the frame and returns the frame
// value.
//
// If the encoded frame representation is invalid, an error is
// returned. This should terminate the **current** connection
// but should not impact any other connected client.
let frame = Frame::parse(&mut buf)?;
// Discard the parsed data from the read buffer.
//
// When `advance` is called on the read buffer, all of the data
// up to `len` is discarded. The details of how this works is
// left to `BytesMut`. This is often done by moving an internal
// cursor, but it may be done by reallocating and copying data.
self.buffer.advance(len);
// Return the parsed frame to the caller.
Ok(Some(frame))
}
// There is not enough data present in the read buffer to parse a
// single frame. We must wait for more data to be received from the
// socket. Reading from the socket will be done in the statement
// after this `match`.
//
// We do not want to return `Err` from here as this "error" is an
// expected runtime condition.
Err(Incomplete) => Ok(None),
// An error was encountered while parsing the frame. The connection
// is now in an invalid state. Returning `Err` from here will result
// in the connection being closed.
Err(e) => Err(e.into()),
}
}
/// Write a single `Frame` value to the underlying stream.
///
/// The `Frame` value is written to the socket using the various `write_*`
/// functions provided by `AsyncWrite`. Calling these functions directly on
/// a `TcpStream` is **not** advised, as this will result in a large number of
/// syscalls. However, it is fine to call these functions on a *buffered*
/// write stream. The data will be written to the buffer. Once the buffer is
/// full, it is flushed to the underlying socket.
pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> {
// Arrays are encoded by encoding each entry. All other frame types are
// considered literals. For now, mini-redis is not able to encode
// recursive frame structures. See below for more details.
match frame {
Frame::Array(val) => {
// Encode the frame type prefix. For an array, it is `*`.
self.stream.write_u8(b'*').await?;
// Encode the length of the array.
self.write_decimal(val.len() as u64).await?;
// Iterate and encode each entry in the array.
for entry in &**val {
self.write_value(entry).await?;
}
}
// The frame type is a literal. Encode the value directly.
_ => self.write_value(frame).await?,
}
// Ensure the encoded frame is written to the socket. The calls above
// are to the buffered stream and writes. Calling `flush` writes the
// remaining contents of the buffer to the socket.
self.stream.flush().await
}
/// Write a frame literal to the stream
async fn write_value(&mut self, frame: &Frame) -> io::Result<()> {
match frame {
Frame::Simple(val) => {
self.stream.write_u8(b'+').await?;
self.stream.write_all(val.as_bytes()).await?;
self.stream.write_all(b"\r\n").await?;
}
Frame::Error(val) => {
self.stream.write_u8(b'-').await?;
self.stream.write_all(val.as_bytes()).await?;
self.stream.write_all(b"\r\n").await?;
}
Frame::Integer(val) => {
self.stream.write_u8(b':').await?;
self.write_decimal(*val).await?;
}
Frame::Null => {
self.stream.write_all(b"$-1\r\n").await?;
}
Frame::Bulk(val) => {
let len = val.len();
self.stream.write_u8(b'$').await?;
self.write_decimal(len as u64).await?;
self.stream.write_all(val).await?;
self.stream.write_all(b"\r\n").await?;
}
// Encoding an `Array` from within a value cannot be done using a
// recursive strategy. In general, async fns do not support
// recursion. Mini-redis has not needed to encode nested arrays yet,
// so for now it is skipped.
Frame::Array(_val) => unreachable!(),
}
Ok(())
}
/// Write a decimal frame to the stream
async fn write_decimal(&mut self, val: u64) -> io::Result<()> {
use std::io::Write;
// Convert the value to a string
let mut buf = [0u8; 20];
let mut buf = Cursor::new(&mut buf[..]);
write!(&mut buf, "{}", val)?;
let pos = buf.position() as usize;
self.stream.write_all(&buf.get_ref()[..pos]).await?;
self.stream.write_all(b"\r\n").await?;
Ok(())
}
}

View File

@ -0,0 +1,378 @@
use tokio::sync::{broadcast, Notify};
use tokio::time::{self, Duration, Instant};
use bytes::Bytes;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex};
use tracing::debug;
/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
/// of the `Db` by signalling the background purge task to shut down when
/// this struct is dropped.
#[derive(Debug)]
pub(crate) struct DbDropGuard {
/// The `Db` instance that will be shut down when this `DbHolder` struct
/// is dropped.
db: Db,
}
/// Server state shared across all connections.
///
/// `Db` contains a `HashMap` storing the key/value data and all
/// `broadcast::Sender` values for active pub/sub channels.
///
/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and
/// only incurs an atomic ref count increment.
///
/// When a `Db` value is created, a background task is spawned. This task is
/// used to expire values after the requested duration has elapsed. The task
/// runs until all instances of `Db` are dropped, at which point the task
/// terminates.
#[derive(Debug, Clone)]
pub(crate) struct Db {
/// Handle to shared state. The background task will also have an
/// `Arc<Shared>`.
shared: Arc<Shared>,
}
#[derive(Debug)]
struct Shared {
/// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and
/// not a Tokio mutex. This is because there are no asynchronous operations
/// being performed while holding the mutex. Additionally, the critical
/// sections are very small.
///
/// A Tokio mutex is mostly intended to be used when locks need to be held
/// across `.await` yield points. All other cases are **usually** best
/// served by a std mutex. If the critical section does not include any
/// async operations but is long (CPU intensive or performing blocking
/// operations), then the entire operation, including waiting for the mutex,
/// is considered a "blocking" operation and `tokio::task::spawn_blocking`
/// should be used.
state: Mutex<State>,
/// Notifies the background task handling entry expiration. The background
/// task waits on this to be notified, then checks for expired values or the
/// shutdown signal.
background_task: Notify,
}
#[derive(Debug)]
struct State {
/// The key-value data. We are not trying to do anything fancy so a
/// `std::collections::HashMap` works fine.
entries: HashMap<String, Entry>,
/// The pub/sub key-space. Redis uses a **separate** key space for key-value
/// and pub/sub. `mini-redis` handles this by using a separate `HashMap`.
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
/// Tracks key TTLs.
///
/// A `BTreeMap` is used to maintain expirations sorted by when they expire.
/// This allows the background task to iterate this map to find the value
/// expiring next.
///
/// While highly unlikely, it is possible for more than one expiration to be
/// created for the same instant. Because of this, the `Instant` is
/// insufficient for the key. A unique expiration identifier (`u64`) is used
/// to break these ties.
expirations: BTreeMap<(Instant, u64), String>,
/// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why.
next_id: u64,
/// True when the Db instance is shutting down. This happens when all `Db`
/// values drop. Setting this to `true` signals to the background task to
/// exit.
shutdown: bool,
}
/// Entry in the key-value store
#[derive(Debug)]
struct Entry {
/// Uniquely identifies this entry.
id: u64,
/// Stored data
data: Bytes,
/// Instant at which the entry expires and should be removed from the
/// database.
expires_at: Option<Instant>,
}
impl DbDropGuard {
/// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped
/// the `Db`'s purge task will be shut down.
pub(crate) fn new() -> DbDropGuard {
DbDropGuard { db: Db::new() }
}
/// Get the shared database. Internally, this is an
/// `Arc`, so a clone only increments the ref count.
pub(crate) fn db(&self) -> Db {
self.db.clone()
}
}
impl Drop for DbDropGuard {
fn drop(&mut self) {
// Signal the 'Db' instance to shut down the task that purges expired keys
self.db.shutdown_purge_task();
}
}
impl Db {
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
/// background task to manage key expiration.
pub(crate) fn new() -> Db {
let shared = Arc::new(Shared {
state: Mutex::new(State {
entries: HashMap::new(),
pub_sub: HashMap::new(),
expirations: BTreeMap::new(),
next_id: 0,
shutdown: false,
}),
background_task: Notify::new(),
});
// Start the background task.
tokio::spawn(purge_expired_tasks(shared.clone()));
Db { shared }
}
/// Get the value associated with a key.
///
/// Returns `None` if there is no value associated with the key. This may be
/// due to never having assigned a value to the key or a previously assigned
/// value expired.
pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
// Acquire the lock, get the entry and clone the value.
//
// Because data is stored using `Bytes`, a clone here is a shallow
// clone. Data is not copied.
let state = self.shared.state.lock().unwrap();
state.entries.get(key).map(|entry| entry.data.clone())
}
/// Set the value associated with a key along with an optional expiration
/// Duration.
///
/// If a value is already associated with the key, it is removed.
pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
let mut state = self.shared.state.lock().unwrap();
// Get and increment the next insertion ID. Guarded by the lock, this
// ensures a unique identifier is associated with each `set` operation.
let id = state.next_id;
state.next_id += 1;
// If this `set` becomes the key that expires **next**, the background
// task needs to be notified so it can update its state.
//
// Whether or not the task needs to be notified is computed during the
// `set` routine.
let mut notify = false;
let expires_at = expire.map(|duration| {
// `Instant` at which the key expires.
let when = Instant::now() + duration;
// Only notify the worker task if the newly inserted expiration is the
// **next** key to evict. In this case, the worker needs to be woken up
// to update its state.
notify = state
.next_expiration()
.map(|expiration| expiration > when)
.unwrap_or(true);
// Track the expiration.
state.expirations.insert((when, id), key.clone());
when
});
// Insert the entry into the `HashMap`.
let prev = state.entries.insert(
key,
Entry {
id,
data: value,
expires_at,
},
);
// If there was a value previously associated with the key **and** it
// had an expiration time. The associated entry in the `expirations` map
// must also be removed. This avoids leaking data.
if let Some(prev) = prev {
if let Some(when) = prev.expires_at {
// clear expiration
state.expirations.remove(&(when, prev.id));
}
}
// Release the mutex before notifying the background task. This helps
// reduce contention by avoiding the background task waking up only to
// be unable to acquire the mutex due to this function still holding it.
drop(state);
if notify {
// Finally, only notify the background task if it needs to update
// its state to reflect a new expiration.
self.shared.background_task.notify_one();
}
}
/// Returns a `Receiver` for the requested channel.
///
/// The returned `Receiver` is used to receive values broadcast by `PUBLISH`
/// commands.
pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
use std::collections::hash_map::Entry;
// Acquire the mutex
let mut state = self.shared.state.lock().unwrap();
// If there is no entry for the requested channel, then create a new
// broadcast channel and associate it with the key. If one already
// exists, return an associated receiver.
match state.pub_sub.entry(key) {
Entry::Occupied(e) => e.get().subscribe(),
Entry::Vacant(e) => {
// No broadcast channel exists yet, so create one.
//
// The channel is created with a capacity of `1024` messages. A
// message is stored in the channel until **all** subscribers
// have seen it. This means that a slow subscriber could result
// in messages being held indefinitely.
//
// When the channel's capacity fills up, publishing will result
// in old messages being dropped. This prevents slow consumers
// from blocking the entire system.
let (tx, rx) = broadcast::channel(1024);
e.insert(tx);
rx
}
}
}
/// Publish a message to the channel. Returns the number of subscribers
/// listening on the channel.
pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize {
let state = self.shared.state.lock().unwrap();
state
.pub_sub
.get(key)
// On a successful message send on the broadcast channel, the number
// of subscribers is returned. An error indicates there are no
// receivers, in which case, `0` should be returned.
.map(|tx| tx.send(value).unwrap_or(0))
// If there is no entry for the channel key, then there are no
// subscribers. In this case, return `0`.
.unwrap_or(0)
}
/// Signals the purge background task to shut down. This is called by the
/// `DbShutdown`s `Drop` implementation.
fn shutdown_purge_task(&self) {
// The background task must be signaled to shut down. This is done by
// setting `State::shutdown` to `true` and signalling the task.
let mut state = self.shared.state.lock().unwrap();
state.shutdown = true;
// Drop the lock before signalling the background task. This helps
// reduce lock contention by ensuring the background task doesn't
// wake up only to be unable to acquire the mutex.
drop(state);
self.shared.background_task.notify_one();
}
}
impl Shared {
/// Purge all expired keys and return the `Instant` at which the **next**
/// key will expire. The background task will sleep until this instant.
fn purge_expired_keys(&self) -> Option<Instant> {
let mut state = self.state.lock().unwrap();
if state.shutdown {
// The database is shutting down. All handles to the shared state
// have dropped. The background task should exit.
return None;
}
// This is needed to make the borrow checker happy. In short, `lock()`
// returns a `MutexGuard` and not a `&mut State`. The borrow checker is
// not able to see "through" the mutex guard and determine that it is
// safe to access both `state.expirations` and `state.entries` mutably,
// so we get a "real" mutable reference to `State` outside of the loop.
let state = &mut *state;
// Find all keys scheduled to expire **before** now.
let now = Instant::now();
while let Some((&(when, id), key)) = state.expirations.iter().next() {
if when > now {
// Done purging, `when` is the instant at which the next key
// expires. The worker task will wait until this instant.
return Some(when);
}
// The key expired, remove it
state.entries.remove(key);
state.expirations.remove(&(when, id));
}
None
}
/// Returns `true` if the database is shutting down
///
/// The `shutdown` flag is set when all `Db` values have dropped, indicating
/// that the shared state can no longer be accessed.
fn is_shutdown(&self) -> bool {
self.state.lock().unwrap().shutdown
}
}
impl State {
fn next_expiration(&self) -> Option<Instant> {
self.expirations
.keys()
.next()
.map(|expiration| expiration.0)
}
}
/// Routine executed by the background task.
///
/// Wait to be notified. On notification, purge any expired keys from the shared
/// state handle. If `shutdown` is set, terminate the task.
async fn purge_expired_tasks(shared: Arc<Shared>) {
// If the shutdown flag is set, then the task should exit.
while !shared.is_shutdown() {
// Purge all keys that are expired. The function returns the instant at
// which the **next** key will expire. The worker should wait until the
// instant has passed then purge again.
if let Some(when) = shared.purge_expired_keys() {
// Wait until the next key expires **or** until the background task
// is notified. If the task is notified, then it must reload its
// state as new keys have been set to expire early. This is done by
// looping.
tokio::select! {
_ = time::sleep_until(when) => {}
_ = shared.background_task.notified() => {}
}
} else {
// There are no keys expiring in the future. Wait until the task is
// notified.
shared.background_task.notified().await;
}
}
debug!("Purge background task shut down")
}

View File

@ -0,0 +1,300 @@
//! Provides a type representing a Redis protocol frame as well as utilities for
//! parsing frames from a byte array.
use bytes::{Buf, Bytes};
use std::convert::TryInto;
use std::fmt;
use std::io::Cursor;
use std::num::TryFromIntError;
use std::string::FromUtf8Error;
/// A frame in the Redis protocol.
#[derive(Clone, Debug)]
pub enum Frame {
Simple(String),
Error(String),
Integer(u64),
Bulk(Bytes),
Null,
Array(Vec<Frame>),
}
#[derive(Debug)]
pub enum Error {
/// Not enough data is available to parse a message
Incomplete,
/// Invalid message encoding
Other(crate::Error),
}
impl Frame {
/// Returns an empty array
pub(crate) fn array() -> Frame {
Frame::Array(vec![])
}
/// Push a "bulk" frame into the array. `self` must be an Array frame.
///
/// # Panics
///
/// panics if `self` is not an array
pub(crate) fn push_bulk(&mut self, bytes: Bytes) {
match self {
Frame::Array(vec) => {
vec.push(Frame::Bulk(bytes));
}
_ => panic!("not an array frame"),
}
}
/// Push an "integer" frame into the array. `self` must be an Array frame.
///
/// # Panics
///
/// panics if `self` is not an array
pub(crate) fn push_int(&mut self, value: u64) {
match self {
Frame::Array(vec) => {
vec.push(Frame::Integer(value));
}
_ => panic!("not an array frame"),
}
}
/// Checks if an entire message can be decoded from `src`
pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> {
match get_u8(src)? {
b'+' => {
get_line(src)?;
Ok(())
}
b'-' => {
get_line(src)?;
Ok(())
}
b':' => {
let _ = get_decimal(src)?;
Ok(())
}
b'$' => {
if b'-' == peek_u8(src)? {
// Skip '-1\r\n'
skip(src, 4)
} else {
// Read the bulk string
let len: usize = get_decimal(src)?.try_into()?;
// skip that number of bytes + 2 (\r\n).
skip(src, len + 2)
}
}
b'*' => {
let len = get_decimal(src)?;
for _ in 0..len {
Frame::check(src)?;
}
Ok(())
}
actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
}
}
/// The message has already been validated with `check`.
pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, Error> {
match get_u8(src)? {
b'+' => {
// Read the line and convert it to `Vec<u8>`
let line = get_line(src)?.to_vec();
// Convert the line to a String
let string = String::from_utf8(line)?;
Ok(Frame::Simple(string))
}
b'-' => {
// Read the line and convert it to `Vec<u8>`
let line = get_line(src)?.to_vec();
// Convert the line to a String
let string = String::from_utf8(line)?;
Ok(Frame::Error(string))
}
b':' => {
let len = get_decimal(src)?;
Ok(Frame::Integer(len))
}
b'$' => {
if b'-' == peek_u8(src)? {
let line = get_line(src)?;
if line != b"-1" {
return Err("protocol error; invalid frame format".into());
}
Ok(Frame::Null)
} else {
// Read the bulk string
let len = get_decimal(src)?.try_into()?;
let n = len + 2;
if src.remaining() < n {
return Err(Error::Incomplete);
}
let data = Bytes::copy_from_slice(&src.chunk()[..len]);
// skip that number of bytes + 2 (\r\n).
skip(src, n)?;
Ok(Frame::Bulk(data))
}
}
b'*' => {
let len = get_decimal(src)?.try_into()?;
let mut out = Vec::with_capacity(len);
for _ in 0..len {
out.push(Frame::parse(src)?);
}
Ok(Frame::Array(out))
}
_ => unimplemented!(),
}
}
/// Converts the frame to an "unexpected frame" error
pub(crate) fn to_error(&self) -> crate::Error {
format!("unexpected frame: {}", self).into()
}
}
impl PartialEq<&str> for Frame {
fn eq(&self, other: &&str) -> bool {
match self {
Frame::Simple(s) => s.eq(other),
Frame::Bulk(s) => s.eq(other),
_ => false,
}
}
}
impl fmt::Display for Frame {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
use std::str;
match self {
Frame::Simple(response) => response.fmt(fmt),
Frame::Error(msg) => write!(fmt, "error: {}", msg),
Frame::Integer(num) => num.fmt(fmt),
Frame::Bulk(msg) => match str::from_utf8(msg) {
Ok(string) => string.fmt(fmt),
Err(_) => write!(fmt, "{:?}", msg),
},
Frame::Null => "(nil)".fmt(fmt),
Frame::Array(parts) => {
for (i, part) in parts.iter().enumerate() {
if i > 0 {
write!(fmt, " ")?;
part.fmt(fmt)?;
}
}
Ok(())
}
}
}
}
fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
if !src.has_remaining() {
return Err(Error::Incomplete);
}
Ok(src.chunk()[0])
}
fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
if !src.has_remaining() {
return Err(Error::Incomplete);
}
Ok(src.get_u8())
}
fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> {
if src.remaining() < n {
return Err(Error::Incomplete);
}
src.advance(n);
Ok(())
}
/// Read a new-line terminated decimal
fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, Error> {
use atoi::atoi;
let line = get_line(src)?;
atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
}
/// Find a line
fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> {
// Scan the bytes directly
let start = src.position() as usize;
// Scan to the second to last byte
let end = src.get_ref().len() - 1;
for i in start..end {
if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
// We found a line, update the position to be *after* the \n
src.set_position((i + 2) as u64);
// Return the line
return Ok(&src.get_ref()[start..i]);
}
}
Err(Error::Incomplete)
}
impl From<String> for Error {
fn from(src: String) -> Error {
Error::Other(src.into())
}
}
impl From<&str> for Error {
fn from(src: &str) -> Error {
src.to_string().into()
}
}
impl From<FromUtf8Error> for Error {
fn from(_src: FromUtf8Error) -> Error {
"protocol error; invalid frame format".into()
}
}
impl From<TryFromIntError> for Error {
fn from(_src: TryFromIntError) -> Error {
"protocol error; invalid frame format".into()
}
}
impl std::error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Incomplete => "stream ended early".fmt(fmt),
Error::Other(err) => err.fmt(fmt),
}
}
}

View File

@ -0,0 +1,76 @@
//! A minimal (i.e. very incomplete) implementation of a Redis server and
//! client.
//!
//! The purpose of this project is to provide a larger example of an
//! asynchronous Rust project built with Tokio. Do not attempt to run this in
//! production... seriously.
//!
//! # Layout
//!
//! The library is structured such that it can be used with guides. There are
//! modules that are public that probably would not be public in a "real" redis
//! client library.
//!
//! The major components are:
//!
//! * `server`: Redis server implementation. Includes a single `run` function
//! that takes a `TcpListener` and starts accepting redis client connections.
//!
//! * `client`: an asynchronous Redis client implementation. Demonstrates how to
//! build clients with Tokio.
//!
//! * `cmd`: implementations of the supported Redis commands.
//!
//! * `frame`: represents a single Redis protocol frame. A frame is used as an
//! intermediate representation between a "command" and the byte
//! representation.
pub mod blocking_client;
pub mod client;
pub mod cmd;
pub use cmd::Command;
mod connection;
pub use connection::Connection;
pub mod frame;
pub use frame::Frame;
mod db;
use db::Db;
use db::DbDropGuard;
mod parse;
use parse::{Parse, ParseError};
pub mod server;
mod buffer;
pub use buffer::{buffer, Buffer};
mod shutdown;
use shutdown::Shutdown;
/// Default port that a redis server listens on.
///
/// Used if no port is specified.
pub const DEFAULT_PORT: &str = "6379";
/// Error returned by most functions.
///
/// When writing a real application, one might want to consider a specialized
/// error handling crate or defining an error type as an `enum` of causes.
/// However, for our example, using a boxed `std::error::Error` is sufficient.
///
/// For performance reasons, boxing is avoided in any hot path. For example, in
/// `parse`, a custom error `enum` is defined. This is because the error is hit
/// and handled during normal execution when a partial frame is received on a
/// socket. `std::error::Error` is implemented for `parse::Error` which allows
/// it to be converted to `Box<dyn std::error::Error>`.
pub type Error = Box<dyn std::error::Error + Send + Sync>;
/// A specialized `Result` type for mini-redis operations.
///
/// This is defined as a convenience.
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -0,0 +1,152 @@
use crate::Frame;
use bytes::Bytes;
use std::{fmt, str, vec};
/// Utility for parsing a command
///
/// Commands are represented as array frames. Each entry in the frame is a
/// "token". A `Parse` is initialized with the array frame and provides a
/// cursor-like API. Each command struct includes a `parse_frame` method that
/// uses a `Parse` to extract its fields.
#[derive(Debug)]
pub(crate) struct Parse {
/// Array frame iterator.
parts: vec::IntoIter<Frame>,
}
/// Error encountered while parsing a frame.
///
/// Only `EndOfStream` errors are handled at runtime. All other errors result in
/// the connection being terminated.
#[derive(Debug)]
pub(crate) enum ParseError {
/// Attempting to extract a value failed due to the frame being fully
/// consumed.
EndOfStream,
/// All other errors
Other(crate::Error),
}
impl Parse {
/// Create a new `Parse` to parse the contents of `frame`.
///
/// Returns `Err` if `frame` is not an array frame.
pub(crate) fn new(frame: Frame) -> Result<Parse, ParseError> {
let array = match frame {
Frame::Array(array) => array,
frame => return Err(format!("protocol error; expected array, got {:?}", frame).into()),
};
Ok(Parse {
parts: array.into_iter(),
})
}
/// Return the next entry. Array frames are arrays of frames, so the next
/// entry is a frame.
fn next(&mut self) -> Result<Frame, ParseError> {
self
.parts
.next()
.ok_or(ParseError::EndOfStream)
}
/// Return the next entry as a string.
///
/// If the next entry cannot be represented as a String, then an error is returned.
pub(crate) fn next_string(&mut self) -> Result<String, ParseError> {
match self.next()? {
// Both `Simple` and `Bulk` representation may be strings. Strings
// are parsed to UTF-8.
//
// While errors are stored as strings, they are considered separate
// types.
Frame::Simple(s) => Ok(s),
Frame::Bulk(data) => str::from_utf8(&data[..])
.map(|s| s.to_string())
.map_err(|_| "protocol error; invalid string".into()),
frame => Err(format!(
"protocol error; expected simple frame or bulk frame, got {:?}",
frame
)
.into()),
}
}
/// Return the next entry as raw bytes.
///
/// If the next entry cannot be represented as raw bytes, an error is
/// returned.
pub(crate) fn next_bytes(&mut self) -> Result<Bytes, ParseError> {
match self.next()? {
// Both `Simple` and `Bulk` representation may be raw bytes.
//
// Although errors are stored as strings and could be represented as
// raw bytes, they are considered separate types.
Frame::Simple(s) => Ok(Bytes::from(s.into_bytes())),
Frame::Bulk(data) => Ok(data),
frame => Err(format!(
"protocol error; expected simple frame or bulk frame, got {:?}",
frame
)
.into()),
}
}
/// Return the next entry as an integer.
///
/// This includes `Simple`, `Bulk`, and `Integer` frame types. `Simple` and
/// `Bulk` frame types are parsed.
///
/// If the next entry cannot be represented as an integer, then an error is
/// returned.
pub(crate) fn next_int(&mut self) -> Result<u64, ParseError> {
use atoi::atoi;
const MSG: &str = "protocol error; invalid number";
match self.next()? {
// An integer frame type is already stored as an integer.
Frame::Integer(v) => Ok(v),
// Simple and bulk frames must be parsed as integers. If the parsing
// fails, an error is returned.
Frame::Simple(data) => atoi::<u64>(data.as_bytes()).ok_or_else(|| MSG.into()),
Frame::Bulk(data) => atoi::<u64>(&data).ok_or_else(|| MSG.into()),
frame => Err(format!("protocol error; expected int frame but got {:?}", frame).into()),
}
}
/// Ensure there are no more entries in the array
pub(crate) fn finish(&mut self) -> Result<(), ParseError> {
if self.parts.next().is_none() {
Ok(())
} else {
Err("protocol error; expected end of frame, but there was more".into())
}
}
}
impl From<String> for ParseError {
fn from(src: String) -> ParseError {
ParseError::Other(src.into())
}
}
impl From<&str> for ParseError {
fn from(src: &str) -> ParseError {
src.to_string().into()
}
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ParseError::EndOfStream => "protocol error; unexpected end of stream".fmt(f),
ParseError::Other(err) => err.fmt(f),
}
}
}
impl std::error::Error for ParseError {}

View File

@ -0,0 +1,399 @@
//! Minimal Redis server implementation
//!
//! Provides an async `run` function that listens for inbound connections,
//! spawning a task per connection.
use crate::{Command, Connection, Db, DbDropGuard, Shutdown};
use std::future::Future;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, Semaphore};
use tokio::time::{self, Duration};
use tracing::{debug, error, info, instrument};
/// Server listener state. Created in the `run` call. It includes a `run` method
/// which performs the TCP listening and initialization of per-connection state.
#[derive(Debug)]
struct Listener {
/// Shared database handle.
///
/// Contains the key / value store as well as the broadcast channels for
/// pub/sub.
///
/// This holds a wrapper around an `Arc`. The internal `Db` can be
/// retrieved and passed into the per connection state (`Handler`).
db_holder: DbDropGuard,
/// TCP listener supplied by the `run` caller.
listener: TcpListener,
/// Limit the max number of connections.
///
/// A `Semaphore` is used to limit the max number of connections. Before
/// attempting to accept a new connection, a permit is acquired from the
/// semaphore. If none are available, the listener waits for one.
///
/// When handlers complete processing a connection, the permit is returned
/// to the semaphore.
limit_connections: Arc<Semaphore>,
/// Broadcasts a shutdown signal to all active connections.
///
/// The initial `shutdown` trigger is provided by the `run` caller. The
/// server is responsible for gracefully shutting down active connections.
/// When a connection task is spawned, it is passed a broadcast receiver
/// handle. When a graceful shutdown is initiated, a `()` value is sent via
/// the broadcast::Sender. Each active connection receives it, reaches a
/// safe terminal state, and completes the task.
notify_shutdown: broadcast::Sender<()>,
/// Used as part of the graceful shutdown process to wait for client
/// connections to complete processing.
///
/// Tokio channels are closed once all `Sender` handles go out of scope.
/// When a channel is closed, the receiver receives `None`. This is
/// leveraged to detect all connection handlers completing. When a
/// connection handler is initialized, it is assigned a clone of
/// `shutdown_complete_tx`. When the listener shuts down, it drops the
/// sender held by this `shutdown_complete_tx` field. Once all handler tasks
/// complete, all clones of the `Sender` are also dropped. This results in
/// `shutdown_complete_rx.recv()` completing with `None`. At this point, it
/// is safe to exit the server process.
shutdown_complete_rx: mpsc::Receiver<()>,
shutdown_complete_tx: mpsc::Sender<()>,
}
/// Per-connection handler. Reads requests from `connection` and applies the
/// commands to `db`.
#[derive(Debug)]
struct Handler {
/// Shared database handle.
///
/// When a command is received from `connection`, it is applied with `db`.
/// The implementation of the command is in the `cmd` module. Each command
/// will need to interact with `db` in order to complete the work.
db: Db,
/// The TCP connection decorated with the redis protocol encoder / decoder
/// implemented using a buffered `TcpStream`.
///
/// When `Listener` receives an inbound connection, the `TcpStream` is
/// passed to `Connection::new`, which initializes the associated buffers.
/// `Connection` allows the handler to operate at the "frame" level and keep
/// the byte level protocol parsing details encapsulated in `Connection`.
connection: Connection,
/// Max connection semaphore.
///
/// When the handler is dropped, a permit is returned to this semaphore. If
/// the listener is waiting for connections to close, it will be notified of
/// the newly available permit and resume accepting connections.
limit_connections: Arc<Semaphore>,
/// Listen for shutdown notifications.
///
/// A wrapper around the `broadcast::Receiver` paired with the sender in
/// `Listener`. The connection handler processes requests from the
/// connection until the peer disconnects **or** a shutdown notification is
/// received from `shutdown`. In the latter case, any in-flight work being
/// processed for the peer is continued until it reaches a safe state, at
/// which point the connection is terminated.
shutdown: Shutdown,
/// Not used directly. Instead, when `Handler` is dropped...?
_shutdown_complete: mpsc::Sender<()>,
}
/// Maximum number of concurrent connections the redis server will accept.
///
/// When this limit is reached, the server will stop accepting connections until
/// an active connection terminates.
///
/// A real application will want to make this value configurable, but for this
/// example, it is hard coded.
///
/// This is also set to a pretty low value to discourage using this in
/// production (you'd think that all the disclaimers would make it obvious that
/// this is not a serious project... but I thought that about mini-http as
/// well).
const MAX_CONNECTIONS: usize = 250;
/// Run the mini-redis server.
///
/// Accepts connections from the supplied listener. For each inbound connection,
/// a task is spawned to handle that connection. The server runs until the
/// `shutdown` future completes, at which point the server shuts down
/// gracefully.
///
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
/// listen for a SIGINT signal.
pub async fn run(listener: TcpListener, shutdown: impl Future) {
// When the provided `shutdown` future completes, we must send a shutdown
// message to all active connections. We use a broadcast channel for this
// purpose. The call below ignores the receiver of the broadcast pair, and when
// a receiver is needed, the subscribe() method on the sender is used to create
// one.
let (notify_shutdown, _) = broadcast::channel(1);
let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);
// Initialize the listener state
let mut server = Listener {
listener,
db_holder: DbDropGuard::new(),
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
notify_shutdown,
shutdown_complete_tx,
shutdown_complete_rx,
};
// Concurrently run the server and listen for the `shutdown` signal. The
// server task runs until an error is encountered, so under normal
// circumstances, this `select!` statement runs until the `shutdown` signal
// is received.
//
// `select!` statements are written in the form of:
//
// ```
// <result of async op> = <async op> => <step to perform with result>
// ```
//
// All `<async op>` statements are executed concurrently. Once the **first**
// op completes, its associated `<step to perform with result>` is
// performed.
//
// The `select! macro is a foundational building block for writing
// asynchronous Rust. See the API docs for more details:
//
// https://docs.rs/tokio/*/tokio/macro.select.html
tokio::select! {
res = server.run() => {
// If an error is received here, accepting connections from the TCP
// listener failed multiple times and the server is giving up and
// shutting down.
//
// Errors encountered when handling individual connections do not
// bubble up to this point.
if let Err(err) = res {
error!(cause = %err, "failed to accept");
}
}
_ = shutdown => {
// The shutdown signal has been received.
info!("shutting down");
}
}
// Extract the `shutdown_complete` receiver and transmitter
// explicitly drop `shutdown_transmitter`. This is important, as the
// `.await` below would otherwise never complete.
let Listener {
mut shutdown_complete_rx,
shutdown_complete_tx,
notify_shutdown,
..
} = server;
// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
// receive the shutdown signal and can exit
drop(notify_shutdown);
// Drop final `Sender` so the `Receiver` below can complete
drop(shutdown_complete_tx);
// Wait for all active connections to finish processing. As the `Sender`
// handle held by the listener has been dropped above, the only remaining
// `Sender` instances are held by connection handler tasks. When those drop,
// the `mpsc` channel will close and `recv()` will return `None`.
let _ = shutdown_complete_rx.recv().await;
}
impl Listener {
/// Run the server
///
/// Listen for inbound connections. For each inbound connection, spawn a
/// task to process that connection.
///
/// # Errors
///
/// Returns `Err` if accepting returns an error. This can happen for a
/// number reasons that resolve over time. For example, if the underlying
/// operating system has reached an internal limit for max number of
/// sockets, accept will fail.
///
/// The process is not able to detect when a transient error resolves
/// itself. One strategy for handling this is to implement a back off
/// strategy, which is what we do here.
async fn run(&mut self) -> crate::Result<()> {
info!("accepting inbound connections");
loop {
// Wait for a permit to become available
//
// `acquire` returns a permit that is bound via a lifetime to the
// semaphore. When the permit value is dropped, it is automatically
// returned to the semaphore. This is convenient in many cases.
// However, in this case, the permit must be returned in a different
// task than it is acquired in (the handler task). To do this, we
// "forget" the permit, which drops the permit value **without**
// incrementing the semaphore's permits. Then, in the handler task
// we manually add a new permit when processing completes.
//
// `acquire()` returns `Err` when the semaphore has been closed. We
// don't ever close the sempahore, so `unwrap()` is safe.
self.limit_connections.acquire().await.unwrap().forget();
// Accept a new socket. This will attempt to perform error handling.
// The `accept` method internally attempts to recover errors, so an
// error here is non-recoverable.
let socket = self.accept().await?;
// Create the necessary per-connection handler state.
let mut handler = Handler {
// Get a handle to the shared database.
db: self.db_holder.db(),
// Initialize the connection state. This allocates read/write
// buffers to perform redis protocol frame parsing.
connection: Connection::new(socket),
// The connection state needs a handle to the max connections
// semaphore. When the handler is done processing the
// connection, a permit is added back to the semaphore.
limit_connections: self.limit_connections.clone(),
// Receive shutdown notifications.
shutdown: Shutdown::new(self.notify_shutdown.subscribe()),
// Notifies the receiver half once all clones are
// dropped.
_shutdown_complete: self.shutdown_complete_tx.clone(),
};
// Spawn a new task to process the connections. Tokio tasks are like
// asynchronous green threads and are executed concurrently.
tokio::spawn(async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run().await {
error!(cause = ?err, "connection error");
}
});
}
}
/// Accept an inbound connection.
///
/// Errors are handled by backing off and retrying. An exponential backoff
/// strategy is used. After the first failure, the task waits for 1 second.
/// After the second failure, the task waits for 2 seconds. Each subsequent
/// failure doubles the wait time. If accepting fails on the 6th try after
/// waiting for 64 seconds, then this function returns with an error.
async fn accept(&mut self) -> crate::Result<TcpStream> {
let mut backoff = 1;
// Try to accept a few times
loop {
// Perform the accept operation. If a socket is successfully
// accepted, return it. Otherwise, save the error.
match self.listener.accept().await {
Ok((socket, _)) => return Ok(socket),
Err(err) => {
if backoff > 64 {
// Accept has failed too many times. Return the error.
return Err(err.into());
}
}
}
// Pause execution until the back off period elapses.
time::sleep(Duration::from_secs(backoff)).await;
// Double the back off
backoff *= 2;
}
}
}
impl Handler {
/// Process a single connection.
///
/// Request frames are read from the socket and processed. Responses are
/// written back to the socket.
///
/// Currently, pipelining is not implemented. Pipelining is the ability to
/// process more than one request concurrently per connection without
/// interleaving frames. See for more details:
/// https://redis.io/topics/pipelining
///
/// When the shutdown signal is received, the connection is processed until
/// it reaches a safe state, at which point it is terminated.
#[instrument(skip(self))]
async fn run(&mut self) -> crate::Result<()> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
while !self.shutdown.is_shutdown() {
// While reading a request frame, also listen for the shutdown
// signal.
let maybe_frame = tokio::select! {
res = self.connection.read_frame() => res?,
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
}
};
// If `None` is returned from `read_frame()` then the peer closed
// the socket. There is no further work to do and the task can be
// terminated.
let frame = match maybe_frame {
Some(frame) => frame,
None => return Ok(()),
};
// Convert the redis frame into a command struct. This returns an
// error if the frame is not a valid redis command or it is an
// unsupported command.
let cmd = Command::from_frame(frame)?;
// Logs the `cmd` object. The syntax here is a shorthand provided by
// the `tracing` crate. It can be thought of as similar to:
//
// ```
// debug!(cmd = format!("{:?}", cmd));
// ```
//
// `tracing` provides structured logging, so information is "logged"
// as key-value pairs.
debug!(?cmd);
// Perform the work needed to apply the command. This may mutate the
// database state as a result.
//
// The connection is passed into the apply function which allows the
// command to write response frames directly to the connection. In
// the case of pub/sub, multiple frames may be send back to the
// peer.
cmd.apply(&self.db, &mut self.connection, &mut self.shutdown)
.await?;
}
Ok(())
}
}
impl Drop for Handler {
fn drop(&mut self) {
// Add a permit back to the semaphore.
//
// Doing so unblocks the listener if the max number of
// connections has been reached.
//
// This is done in a `Drop` implementation in order to guarantee that
// the permit is added even if the task handling the connection panics.
// If `add_permit` was called at the end of the `run` function and some
// bug causes a panic. The permit would never be returned to the
// semaphore.
self.limit_connections.add_permits(1);
}
}

View File

@ -0,0 +1,49 @@
use tokio::sync::broadcast;
/// Listens for the server shutdown signal.
///
/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is
/// ever sent. Once a value has been sent via the broadcast channel, the server
/// should shutdown.
///
/// The `Shutdown` struct listens for the signal and tracks that the signal has
/// been received. Callers may query for whether the shutdown signal has been
/// received or not.
#[derive(Debug)]
pub(crate) struct Shutdown {
/// `true` if the shutdown signal has been received
shutdown: bool,
/// The receive half of the channel used to listen for shutdown.
notify: broadcast::Receiver<()>,
}
impl Shutdown {
/// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
Shutdown {
shutdown: false,
notify,
}
}
/// Returns `true` if the shutdown signal has been received.
pub(crate) fn is_shutdown(&self) -> bool {
self.shutdown
}
/// Receive the shutdown notice, waiting if necessary.
pub(crate) async fn recv(&mut self) {
// If the shutdown signal has already been received, then return
// immediately.
if self.shutdown {
return;
}
// Cannot receive a "lag error" as only one value is ever sent.
let _ = self.notify.recv().await;
// Remember that the signal has been received.
self.shutdown = true;
}
}