Added support for compressed packets.

I also had to bring packet serialization/deserialization
back out of PacketFormat so that I could make it a trait object
in the connection (before it was generic over PacketMap).
However, now that Connection abstracts over PacketFormat,
it actually reduced code duplication to do so.

I also reorganized the hierarchy a bit, moving packet formats
under the connection module and most other things under the
protocol module.
master
James T. Martin 2020-07-25 12:05:58 -07:00
parent a89562f9d6
commit 3701b59a11
Signed by: james
GPG Key ID: 4B7F3DA9351E577C
18 changed files with 296 additions and 143 deletions

37
Cargo.lock generated
View File

@ -1,5 +1,11 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "adler"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee2a4ec343196209d6594e19543ae87a39f96d5534d7174822a3ad825dd6ed7e"
[[package]]
name = "ansi_term"
version = "0.11.0"
@ -65,6 +71,27 @@ dependencies = [
"yaml-rust",
]
[[package]]
name = "crc32fast"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba125de2af0df55319f41944744ad91c71113bf74a4646efff39afe1f6842db1"
dependencies = [
"cfg-if",
]
[[package]]
name = "flate2"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68c90b0fc46cf89d227cc78b40e494ff81287a92dd07631e5af0d06fe3cf885e"
dependencies = [
"cfg-if",
"crc32fast",
"libc",
"miniz_oxide",
]
[[package]]
name = "fuchsia-zircon"
version = "0.3.3"
@ -142,6 +169,15 @@ version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
[[package]]
name = "miniz_oxide"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be0f75932c1f6cfae3c04000e40114adf955636e19040f9c0a2c380702aa1c7f"
dependencies = [
"adler",
]
[[package]]
name = "mio"
version = "0.6.22"
@ -304,6 +340,7 @@ version = "0.1.0"
dependencies = [
"async-trait",
"clap",
"flate2",
"serde",
"serde_json",
"tokio",

View File

@ -12,6 +12,7 @@ publish = false
[dependencies]
async-trait = "0.1.36"
clap = { version = "2.33.1", features = ["yaml"] }
flate2 = "1.0.16"
serde = { version = "1.0.114", features = ["derive"] }
serde_json = "1.0.56"
tokio = { version = "0.2.22", features = ["io-util", "macros", "net", "tcp", "rt-threaded"] }

View File

@ -5,10 +5,10 @@ mod net;
use crate::net::chat::Chat;
use crate::net::connection::Connection;
use crate::net::state::handshake::Handshake;
use crate::net::state::login::Login;
use crate::net::state::play::Play;
use crate::net::state::status::Status;
use crate::net::protocol::state::handshake::Handshake;
use crate::net::protocol::state::login::Login;
use crate::net::protocol::state::play::Play;
use crate::net::protocol::state::status::Status;
use tokio::net::{TcpListener, TcpStream};
use std::io;
use std::net::IpAddr;
@ -61,7 +61,7 @@ fn mk_err<A, S: std::borrow::Borrow<str>>(str: S) -> io::Result<A> {
}
async fn interact_handshake(mut con: Connection<Handshake>) -> io::Result<()> {
use crate::net::state::handshake::*;
use crate::net::protocol::state::handshake::*;
match con.read().await? {
Serverbound::HandshakePkt(handshake) => {
@ -76,7 +76,7 @@ async fn interact_handshake(mut con: Connection<Handshake>) -> io::Result<()> {
}
async fn interact_status(mut con: Connection<Status>) -> io::Result<()> {
use crate::net::state::status::*;
use crate::net::protocol::state::status::*;
loop {
match con.read().await? {
@ -110,7 +110,7 @@ async fn interact_status(mut con: Connection<Status>) -> io::Result<()> {
}
async fn interact_login(mut con: Connection<Login>) -> io::Result<()> {
use crate::net::state::login::*;
use crate::net::protocol::state::login::*;
let name = match con.read().await? {
Serverbound::LoginStart(login_start) => {
@ -126,6 +126,8 @@ async fn interact_login(mut con: Connection<Login>) -> io::Result<()> {
eprintln!("Client set username to {}.", name);
con.set_compression(Some(64)).await?;
con.write(&Clientbound::LoginSuccess(LoginSuccess {
uuid: uuid::Uuid::nil(),
username: name,
@ -135,7 +137,7 @@ async fn interact_login(mut con: Connection<Login>) -> io::Result<()> {
}
async fn interact_play(mut con: Connection<Play>) -> io::Result<()> {
use crate::net::state::play::*;
use crate::net::protocol::state::play::*;
con.write(&Clientbound::Disconnect(Disconnect {
reason: Chat { text: "Goodbye!".to_string() }

View File

@ -1,10 +1,14 @@
mod packet_format;
use crate::net::{Reader, Writer};
use crate::net::format::{PacketFormat, DefaultPacketFormat};
use crate::net::state::ProtocolState;
use crate::net::state::handshake::Handshake;
use crate::net::state::login::Login;
use crate::net::state::play::Play;
use crate::net::state::status::Status;
use crate::net::connection::packet_format::PacketFormat;
use crate::net::connection::packet_format::default::DefaultPacketFormat;
use crate::net::protocol::packet_map::PacketMap;
use crate::net::protocol::state::ProtocolState;
use crate::net::protocol::state::handshake::Handshake;
use crate::net::protocol::state::login::Login;
use crate::net::protocol::state::play::Play;
use crate::net::protocol::state::status::Status;
use std::io;
use std::marker::PhantomData;
use tokio::net::TcpStream;
@ -12,25 +16,39 @@ use tokio::net::TcpStream;
pub struct Connection<St: ProtocolState> {
src: Reader,
dest: Writer,
fmt: Box<dyn PacketFormat>,
st: PhantomData<St>,
}
impl<St: ProtocolState> Connection<St> {
pub async fn write(&mut self, pkt: &St::Clientbound) -> io::Result<()> {
DefaultPacketFormat.send::<St::Clientbound>(&mut self.dest, pkt).await
let mut buf = Vec::new();
pkt.write(&mut buf);
self.fmt.send(&mut self.dest, buf.as_ref()).await
}
pub async fn read(&mut self) -> io::Result<St::Serverbound> {
DefaultPacketFormat.recieve::<St::Serverbound>(&mut self.src).await
use crate::net::serialize::VecPacketDeserializer;
let buf = self.fmt.recieve(&mut self.src).await?;
St::Serverbound::read(&mut VecPacketDeserializer::new(buf.as_ref()))
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
}
pub fn into_disconnected(self) -> Connection<!> {
fn into_state<NewSt: ProtocolState>(self) -> Connection<NewSt> {
Connection {
src: self.src,
dest: self.dest,
fmt: self.fmt,
st: PhantomData,
}
}
pub fn into_disconnected(self) -> Connection<!> {
self.into_state()
}
}
impl Connection<Handshake> {
@ -42,33 +60,47 @@ impl Connection<Handshake> {
Connection {
src: BufReader::new(src),
dest: BufWriter::new(dest),
fmt: Box::new(DefaultPacketFormat),
st: PhantomData,
}
}
pub fn into_status(self) -> Connection<Status> {
Connection {
src: self.src,
dest: self.dest,
st: PhantomData,
}
self.into_state()
}
pub fn into_login(self) -> Connection<Login> {
Connection {
src: self.src,
dest: self.dest,
st: PhantomData,
}
self.into_state()
}
}
impl Connection<Login> {
pub fn into_play(self) -> Connection<Play> {
Connection {
src: self.src,
dest: self.dest,
st: PhantomData,
pub async fn set_compression(&mut self, threshold: Option<u32>) -> io::Result<()> {
use crate::net::connection::packet_format::compressed::CompressedPacketFormat;
use crate::net::serialize::VarInt;
use crate::net::protocol::state::login::{Clientbound, SetCompression};
// Tell the client about the new compression threshold,
// using a packet compressed with the old compression threshold.
self.write(&Clientbound::SetCompression(SetCompression {
// A negative threshold will disable compression.
threshold: VarInt(threshold.map(|x| x as i32).unwrap_or(-1)),
})).await?;
// Further packets will use the new compression threshold.
match threshold {
Some(threshold) => {
self.fmt = Box::new(CompressedPacketFormat { threshold: threshold as usize });
},
None => {
self.fmt = Box::new(DefaultPacketFormat);
}
}
Ok(())
}
pub fn into_play(self) -> Connection<Play> {
self.into_state()
}
}

View File

@ -0,0 +1,33 @@
pub mod compressed;
pub mod default;
use async_trait::async_trait;
use crate::net::{Reader, Writer};
use std::io;
use tokio::io::AsyncReadExt;
#[async_trait]
pub trait PacketFormat: Send + Sync {
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()>;
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>>;
}
pub const MAX_CLIENT_PACKET_SIZE: usize = 32767;
pub async fn read_varint(src: &mut Reader) -> io::Result<(usize, i32)> {
let mut length = 1;
let mut acc = 0;
while length <= 5 {
let byte = src.read_u8().await?;
acc |= (byte & 0b01111111) as i32;
if byte & 0b10000000 == 0 {
return Ok((length, acc));
}
acc = acc << 7;
length += 1;
}
Err(io::Error::new(io::ErrorKind::Other, "VarInt was too long.".to_string()))
}

View File

@ -0,0 +1,101 @@
use async_trait::async_trait;
use crate::net::{Reader, Writer};
use crate::net::connection::packet_format::{PacketFormat, MAX_CLIENT_PACKET_SIZE, read_varint};
use std::boxed::Box;
use std::io;
use tokio::io::AsyncReadExt;
pub struct CompressedPacketFormat {
pub threshold: usize,
}
#[async_trait]
impl PacketFormat for CompressedPacketFormat {
async fn send(&self, dest: &mut Writer, uncompressed_data: &[u8]) -> io::Result<()> {
use crate::net::serialize::{PacketSerializer, VarInt};
// If the length of the uncompressed packet is less than the threshold,
// then we do not compress the packet and set the data_length field to 0.
// Otherwise, data_length is set to the length of the uncompressed packet.
let will_compress = uncompressed_data.len() >= self.threshold;
let data_length = if will_compress { uncompressed_data.len() } else { 0 };
let mut data_length_buf = Vec::with_capacity(5);
data_length_buf.write(VarInt(data_length as i32));
let mut compression_buf;
let data = if will_compress {
use flate2::{Compress, Compression, FlushCompress};
// 1024 is just an arbitrary amount of extra space reserved
// in case the output data ends up larger than the input data
// (e.g. due to the zlib header).
// FIXME: Further research to figure out the exact maximum capacity necessary.
// Perhaps you only need space for the header and the data itself can't get bigger?
// And what is the limit to how much bigger the data will get?
// Currently I don't actually know for a fact that this won't ever drop data.
compression_buf = Vec::with_capacity(1024 + uncompressed_data.len());
Compress::new(Compression::best(), true)
.compress_vec(uncompressed_data, &mut compression_buf, FlushCompress::Finish)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
compression_buf.as_slice()
} else {
uncompressed_data
};
let mut packet_length_buf = Vec::with_capacity(5);
packet_length_buf.write(VarInt((data_length_buf.len() + data.len()) as i32));
{
use tokio::io::AsyncWriteExt;
dest.write(packet_length_buf.as_slice()).await?;
dest.write(data_length_buf.as_slice()).await?;
dest.write(data).await?;
dest.flush().await?;
}
Ok(())
}
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>> {
let (_, packet_length) = read_varint(src).await?;
if packet_length < 0 {
return Err(io::Error::new(io::ErrorKind::Other, "Packet length was negative."));
}
if packet_length > MAX_CLIENT_PACKET_SIZE as i32 {
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long."));
}
let packet_length = packet_length as usize;
let (data_length_size, data_length) = read_varint(src).await?;
if data_length < 0 {
return Err(io::Error::new(io::ErrorKind::Other, "Data length was negative."));
}
if data_length > MAX_CLIENT_PACKET_SIZE as i32 {
return Err(io::Error::new(io::ErrorKind::Other, "Data was too long."));
}
let data_length = data_length as usize;
let mut buf = Vec::with_capacity(packet_length);
buf.resize(packet_length, 0);
src.read_exact(buf.as_mut_slice()).await?;
let decompressed_buf = if data_length != 0 {
use flate2::{Decompress, FlushDecompress};
let mut decompressed_buf = Vec::with_capacity(data_length);
decompressed_buf.resize(data_length, 0);
Decompress::new(true)
.decompress(&buf, decompressed_buf.as_mut_slice(), FlushDecompress::Finish)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
decompressed_buf
} else {
let mut decompressed_buf = Vec::with_capacity(packet_length - data_length_size);
decompressed_buf.copy_from_slice(&buf[data_length_size..]);
decompressed_buf
};
Ok(decompressed_buf.into_boxed_slice())
}
}

View File

@ -0,0 +1,42 @@
use async_trait::async_trait;
use crate::net::{Reader, Writer};
use crate::net::connection::packet_format::{PacketFormat, MAX_CLIENT_PACKET_SIZE, read_varint};
use std::boxed::Box;
use std::io;
use tokio::io::AsyncReadExt;
pub struct DefaultPacketFormat;
#[async_trait]
impl PacketFormat for DefaultPacketFormat {
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()> {
use crate::net::serialize::{PacketSerializer, VarInt};
let mut packet_length_buf = Vec::with_capacity(5);
packet_length_buf.write(VarInt(data.len() as i32));
{
use tokio::io::AsyncWriteExt;
dest.write(packet_length_buf.as_slice()).await?;
dest.write(data).await?;
dest.flush().await?;
}
Ok(())
}
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>> {
let (_, length) = read_varint(src).await?;
if length > MAX_CLIENT_PACKET_SIZE as i32 {
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long.".to_string()));
}
let length = length as usize;
let mut buf = Vec::with_capacity(length);
buf.resize(length, 0);
src.read_exact(buf.as_mut_slice()).await?;
Ok(buf.into_boxed_slice())
}
}

View File

@ -1,82 +0,0 @@
use async_trait::async_trait;
use crate::net::{Reader, Writer};
use crate::net::packet_map::PacketMap;
use std::boxed::Box;
use std::io;
use tokio::io::AsyncReadExt;
#[async_trait]
pub trait PacketFormat {
async fn send<P: PacketMap>(&self, dest: &mut Writer, pkt: &P) -> io::Result<()>;
async fn recieve<P: PacketMap>(&self, src: &mut Reader) -> io::Result<P>;
}
pub const MAX_CLIENT_PACKET_SIZE: usize = 32767;
pub struct DefaultPacketFormat;
async fn read_varint(src: &mut Reader) -> io::Result<(usize, i32)> {
let mut length = 1;
let mut acc = 0;
while length <= 5 {
let byte = src.read_u8().await?;
acc |= (byte & 0b01111111) as i32;
if byte & 0b10000000 == 0 {
return Ok((length, acc));
}
acc = acc << 7;
length += 1;
}
Err(io::Error::new(io::ErrorKind::Other, "VarInt was too long."))
}
#[async_trait]
impl PacketFormat for DefaultPacketFormat {
async fn send<P: PacketMap>(&self, dest: &mut Writer, pkt: &P) -> io::Result<()> {
use crate::net::serialize::{PacketSerializer, VarInt};
let packet_id = pkt.id();
let mut data = Vec::new();
pkt.write(&mut data);
let mut packet_id_buf = Vec::with_capacity(5);
packet_id_buf.write(VarInt(packet_id));
let packet_length = packet_id_buf.len() + data.len();
let mut packet_length_buf = Vec::with_capacity(5);
packet_length_buf.write(VarInt(packet_length as i32));
{
use tokio::io::AsyncWriteExt;
dest.write(packet_length_buf.as_slice()).await?;
dest.write(packet_id_buf.as_slice()).await?;
dest.write(data.as_slice()).await?;
dest.flush().await?;
}
Ok(())
}
async fn recieve<P: PacketMap>(&self, src: &mut Reader) -> io::Result<P> {
use crate::net::serialize::VecPacketDeserializer;
let (_, length) = read_varint(src).await?;
if length > MAX_CLIENT_PACKET_SIZE as i32 {
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long."));
}
let length = length as usize;
let (id_length, id) = read_varint(src).await?;
let mut buf = Vec::with_capacity(length - id_length);
buf.resize(length - id_length, 0);
src.read_exact(buf.as_mut_slice()).await?;
P::read(id, &mut VecPacketDeserializer::new(&buf))
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
}
}

View File

@ -1,10 +1,7 @@
pub mod chat;
pub mod connection;
pub mod format;
pub mod packet;
pub mod packet_map;
pub mod protocol;
pub mod serialize;
pub mod state;
use tokio::io::{BufReader, BufWriter};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

3
src/net/protocol/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod packet_map;
pub mod packet;
pub mod state;

View File

@ -1,20 +1,14 @@
use crate::net::serialize::{PacketDeserializer, PacketSerializer};
pub trait PacketMap: Sized + Sync {
/// Get a packet's id.
fn id(&self) -> i32;
/// Read a packet from the deserializer.
fn read(id: i32, deser: &mut impl PacketDeserializer) -> Result<Self, String>;
fn read(deser: &mut impl PacketDeserializer) -> Result<Self, String>;
/// Write this packet's data to the serializer.
fn write(&self, ser: &mut impl PacketSerializer);
}
impl PacketMap for ! {
fn id(&self) -> i32 {
match *self { }
}
fn read(_id: i32, _deser: &mut impl PacketDeserializer) -> Result<Self, String> {
fn read(_deser: &mut impl PacketDeserializer) -> Result<Self, String> {
Err("Cannot read packets; the connection state is disconnected.".to_string())
}
@ -31,16 +25,11 @@ macro_rules! define_packet_maps {
$( $packet($packet) ),*
}
impl crate::net::packet_map::PacketMap for $name {
fn id(&self) -> i32 {
match *self {
$( $name::$packet(_) => $id ),*
}
}
impl crate::net::protocol::packet_map::PacketMap for $name {
#[allow(unused_variables)]
fn read(id: i32, deser: &mut impl crate::net::serialize::PacketDeserializer)
fn read(deser: &mut impl crate::net::serialize::PacketDeserializer)
-> Result<Self, String> {
let id: i32 = deser.read::<crate::net::serialize::VarInt>()?.into();
match id {
$( $id => deser.read::<$packet>().map($name::$packet), )*
id => Err(format!("Invalid packet id: {}", id))
@ -50,7 +39,10 @@ macro_rules! define_packet_maps {
#[allow(unused_variables)]
fn write(&self, ser: &mut impl crate::net::serialize::PacketSerializer) {
match *self {
$( $name::$packet(ref pkt) => ser.write(pkt) ),*
$( $name::$packet(ref pkt) => {
ser.write(crate::net::serialize::VarInt($id));
ser.write::<&$packet>(&pkt);
} ),*
}
}
}

View File

@ -3,7 +3,7 @@ pub mod login;
pub mod play;
pub mod status;
use crate::net::packet_map::PacketMap;
use crate::net::protocol::packet_map::PacketMap;
pub trait ProtocolState {
type Clientbound: PacketMap;
@ -21,7 +21,7 @@ macro_rules! define_state {
#[allow(dead_code)]
pub enum $name {}
impl crate::net::state::ProtocolState for $name {
impl crate::net::protocol::state::ProtocolState for $name {
type Clientbound = $cb;
type Serverbound = $sb;
}

View File

@ -1,4 +1,3 @@
use crate::net::format::MAX_CLIENT_PACKET_SIZE;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::borrow::Borrow;
@ -199,13 +198,9 @@ impl PacketReadable for Vec<u8> {
fn read(deser: &mut impl PacketDeserializer) -> Result<Self, String> {
let length: i32 = deser.read::<VarInt>()?.into();
if length < 0 {
return Err("String length cannot be negative.".to_string());
return Err("Array or string length cannot be negative.".to_string());
}
let length = length as usize;
if length > MAX_CLIENT_PACKET_SIZE {
return Err("Byte array was too long.".to_string());
}
let mut it = Vec::with_capacity(length);
it.resize(length, 0);