diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a16207e..078a889 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -12,7 +12,7 @@ jobs: - name: Install latest nightly uses: actions-rs/toolchain@v1 with: - toolchain: stable + toolchain: nightly override: true - name: Build diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/src/main.rs b/src/main.rs index 6c6a6aa..a888c06 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(const_generics)] + mod net; use crate::net::{Reader, Writer}; diff --git a/src/net/chat.rs b/src/net/chat.rs index 63310d5..9c736e6 100644 --- a/src/net/chat.rs +++ b/src/net/chat.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; // TODO: Support more features. -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct Chat { pub text: String, } diff --git a/src/net/format.rs b/src/net/format.rs index 78b6f2c..ac8a019 100644 --- a/src/net/format.rs +++ b/src/net/format.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use crate::net::{Reader, Writer}; use std::boxed::Box; use std::io; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::AsyncReadExt; #[async_trait] pub trait PacketFormat { @@ -35,19 +35,23 @@ async fn read_varint<'rr, 'r>(src: &'rr mut Reader<'r>) -> io::Result<(usize, i3 #[async_trait] impl PacketFormat for DefaultPacketFormat { async fn send<'wr, 'w>(&self, dest: &'wr mut Writer<'w>, packet_id: i32, data: &[u8]) -> io::Result<()> { - use crate::net::serialize::PacketSerializer; + use crate::net::serialize::{PacketSerializer, VarInt}; let mut packet_id_buf = Vec::with_capacity(5); - packet_id_buf.write_varint(packet_id); + packet_id_buf.write(VarInt(packet_id)); let packet_length = packet_id_buf.len() + data.len(); let mut packet_length_buf = Vec::with_capacity(5); - packet_length_buf.write_varint(packet_length as i32); + packet_length_buf.write(VarInt(packet_length as i32)); - dest.write(packet_length_buf.as_slice()).await?; - dest.write(packet_id_buf.as_slice()).await?; - dest.write(data).await?; - dest.flush().await?; + { + use tokio::io::AsyncWriteExt; + + dest.write(packet_length_buf.as_slice()).await?; + dest.write(packet_id_buf.as_slice()).await?; + dest.write(data).await?; + dest.flush().await?; + } Ok(()) } diff --git a/src/net/packet.rs b/src/net/packet.rs index 45704fc..9651198 100644 --- a/src/net/packet.rs +++ b/src/net/packet.rs @@ -1,2 +1,2 @@ pub mod handshake; -pub mod status; \ No newline at end of file +pub mod status; diff --git a/src/net/packet/handshake.rs b/src/net/packet/handshake.rs index cdf5949..c293b7e 100644 --- a/src/net/packet/handshake.rs +++ b/src/net/packet/handshake.rs @@ -1,4 +1,4 @@ -use crate::net::serialize::PacketDeserializer; +use crate::net::serialize::{PacketData, PacketDeserializer, PacketSerializer, VarInt}; #[derive(Debug, PartialEq, Eq)] pub enum HandshakeNextState { @@ -14,6 +14,32 @@ pub struct PacketHandshake { pub next_state: HandshakeNextState, } +impl PacketData for PacketHandshake { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let protocol_version = deser.read::()?.into(); + let server_address = deser.read::()?; + let server_port = deser.read::()?; + let next_state = match deser.read::()?.into() { + 1 => HandshakeNextState::Status, + 2 => HandshakeNextState::Login, + n => return Err(format!("Invalid next protocol state in handshake: {}", n)) + }; + deser.read_eof()?; + Ok(PacketHandshake { + protocol_version: protocol_version, + server_address: server_address, + server_port: server_port, + next_state: next_state, + }) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.protocol_version); + ser.write(self.server_address.clone()); + ser.write(self.server_port); + } +} + #[derive(Debug)] pub enum PacketHandshakeServerbound { Handshake(PacketHandshake), @@ -24,23 +50,7 @@ pub fn read_packet_handshake(id: i32, deser: &mut impl PacketDeserializer) use PacketHandshakeServerbound::*; match id { - 0x00 => { - let protocol_version = deser.read_varint()?; - let server_address = deser.read_string()?; - let server_port = deser.read_u16()?; - let next_state = match deser.read_varint()? { - 1 => HandshakeNextState::Status, - 2 => HandshakeNextState::Login, - n => return Err(format!("Invalid next protocol state in handshake: {}", n)) - }; - deser.read_eof()?; - Ok(Handshake(PacketHandshake { - protocol_version: protocol_version, - server_address: server_address, - server_port: server_port, - next_state: next_state, - })) - }, + 0x00 => deser.read().map(Handshake), id => Err(format!("Invalid handshake packet id: {}", id)) } } diff --git a/src/net/packet/status.rs b/src/net/packet/status.rs index 8383d0e..bc69bfd 100644 --- a/src/net/packet/status.rs +++ b/src/net/packet/status.rs @@ -1,29 +1,31 @@ use crate::net::chat::Chat; -use crate::net::serialize::{PacketSerializer, PacketDeserializer}; -use serde::Serialize; -use std::convert::TryInto; +use crate::net::serialize::{PacketDeserializer, PacketSerializer, PacketJson}; +use serde::{Deserialize, Serialize}; use uuid::Uuid; -#[derive(Serialize)] +#[derive(Clone, Deserialize, Serialize)] pub struct PacketResponseVersion { pub name: String, pub protocol: u32, } +impl PacketJson for PacketResponseVersion {} -#[derive(Serialize)] +#[derive(Clone, Deserialize, Serialize)] pub struct PacketResponsePlayersSample { pub name: String, pub id: Uuid, } +impl PacketJson for PacketResponsePlayersSample {} -#[derive(Serialize)] +#[derive(Clone, Deserialize, Serialize)] pub struct PacketResponsePlayers { pub max: u32, pub online: u32, pub sample: Vec } +impl PacketJson for PacketResponsePlayers {} -#[derive(Serialize)] +#[derive(Clone, Deserialize, Serialize)] pub struct PacketResponse { pub version: PacketResponseVersion, pub players: PacketResponsePlayers, @@ -31,6 +33,7 @@ pub struct PacketResponse { #[serde(skip_serializing_if = "Option::is_none")] pub favicon: Option, } +impl PacketJson for PacketResponse {} pub enum PacketStatusClientbound { Response(PacketResponse), @@ -53,10 +56,9 @@ pub fn read_packet_status(id: i32, deser: &mut impl PacketDeserializer) Ok(Request) }, 1 => { - let mut buf = [0; 8]; - deser.read(&mut buf)?; + let payload = deser.read::<[u8; 8]>()?; deser.read_eof()?; - Ok(Ping(buf.try_into().unwrap())) + Ok(Ping(payload)) } id => Err(format!("Invalid status packet id: {}", id)) } @@ -68,11 +70,11 @@ pub fn write_packet_status(ser: &mut impl PacketSerializer, packet: PacketStatus match packet { Response(response) => { - ser.write_json(&response); + ser.write(response); 0x00 }, Pong(payload) => { - ser.write(&payload); + ser.write(payload); 0x01 } } diff --git a/src/net/serialize.rs b/src/net/serialize.rs index 0c2687c..f351648 100644 --- a/src/net/serialize.rs +++ b/src/net/serialize.rs @@ -1,85 +1,147 @@ use crate::net::format::MAX_CLIENT_PACKET_SIZE; use serde::Serialize; +use serde::de::DeserializeOwned; +use std::convert::{From, Into}; -pub trait PacketSerializer { - /// Write a slice of bytes directly, without a length prefix. - fn write(&mut self, value: &[u8]); +pub trait PacketData: Sized { + fn read(deser: &mut impl PacketDeserializer) -> Result + where Self: std::marker::Sized; + fn write(&self, ser: &mut impl PacketSerializer); +} - fn write_u8(&mut self, value: u8) { - self.write(&value.to_le_bytes()); +impl PacketData for [u8; N] { + fn read(deser: &mut impl PacketDeserializer) -> Result { + use std::convert::TryInto; + + let mut buf = [0; N]; + deser.read_exact(&mut buf)?; + Ok(buf.try_into().unwrap()) } - /// https://wiki.vg/Protocol#VarInt_and_VarLong - fn write_varint(&mut self, mut value: i32) { - loop { - let mut temp = (value & 0b01111111) as u8; - value = value >> 7; - if value != 0 { - temp |= 0b10000000; - } - self.write_u8(temp); + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write_exact(self); + } +} - if value == 0 { - break; - } +impl PacketData for bool { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let value = deser.read::()?; + match value { + 0x00 => Ok(false), + 0x01 => Ok(true), + n => Err(format!("{:0X} is not a valid boolean.", n)) } } - /// Write a varint-length-prefixed byte slice. - fn write_slice(&mut self, value: &[u8]) { - self.write_varint(value.len() as i32); - self.write(value); - } - - /// Write a varint-length-prefixed string. - fn write_str(&mut self, value: &str) { - self.write_slice(value.as_bytes()); - } - - fn write_json(&mut self, value: &impl Serialize) { - self.write_slice(serde_json::to_vec(value).unwrap().as_slice()) + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(*self as u8); } } -impl PacketSerializer for Vec { - fn write(&mut self, value: &[u8]) { - self.extend_from_slice(value); +impl PacketData for u8 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 1]>().map(u8::from_be_bytes) } - fn write_u8(&mut self, value: u8) { - self.push(value); + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()) } } -pub trait PacketDeserializer { - fn read(&mut self, buf: &mut [u8]) -> Result<(), String>; - fn read_eof(&mut self) -> Result<(), String>; - - fn read_u8(&mut self) -> Result { - let mut buf = [0; 1]; - self.read(&mut buf)?; - Ok(u8::from_le_bytes(buf)) +impl PacketData for i8 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 1]>().map(i8::from_be_bytes) } - fn read_u16(&mut self) -> Result { - let mut buf = [0; 2]; - self.read(&mut buf)?; - Ok(u16::from_le_bytes(buf)) + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()) + } +} + +impl PacketData for u16 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 2]>().map(u16::from_be_bytes) } - fn read_varint(&mut self) -> Result { + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +impl PacketData for i16 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 2]>().map(i16::from_be_bytes) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +impl PacketData for i32 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 4]>().map(i32::from_be_bytes) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +impl PacketData for i64 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 8]>().map(i64::from_be_bytes) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +impl PacketData for f32 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 4]>().map(f32::from_be_bytes) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +impl PacketData for f64 { + fn read(deser: &mut impl PacketDeserializer) -> Result { + deser.read::<[u8; 8]>().map(f64::from_be_bytes) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.to_be_bytes()); + } +} + +pub struct VarInt(pub i32); + +impl From for VarInt { + fn from(x: i32) -> Self { VarInt(x) } +} + +impl Into for VarInt { + fn into(self) -> i32 { self.0 } +} + +impl PacketData for VarInt { + fn read(deser: &mut impl PacketDeserializer) -> Result { let mut length = 1; let mut acc = 0; // VarInts must not be longer than 5 bytes. while length <= 5 { // If the highest bit is set, there are further bytes to be read from this VarInt; // the rest of the bits are the actual data in the VarInt. - let read = self.read_u8()?; + let read = deser.read::()?; acc |= (read & 0b01111111) as i32; - // There are no mo + // There are no more bytes. if (read & 0b10000000) == 0 { - return Ok(acc); + return Ok(VarInt(acc)); } // Make space for the rest of the bits. @@ -91,21 +153,147 @@ pub trait PacketDeserializer { Err("VarInt was more than 5 bytes.".to_string()) } - fn read_string(&mut self) -> Result { - let length = self.read_varint()?; + fn write(&self, ser: &mut impl PacketSerializer) { + let mut value = self.0; + loop { + let mut temp = (value & 0b01111111) as u8; + value = value >> 7; + if value != 0 { + temp |= 0b10000000; + } + ser.write(temp); + + if value == 0 { + break; + } + } + } +} + +pub struct VarLong(pub i64); + +impl From for VarLong { + fn from(x: i64) -> Self { VarLong(x) } +} + +impl Into for VarLong { + fn into(self) -> i64 { self.0 } +} + +impl PacketData for VarLong { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let mut length = 1; + let mut acc = 0; + // VarLongs must not be longer than 5 bytes. + while length <= 10 { + // If the highest bit is set, there are further bytes to be read from this VarLong; + // the rest of the bits are the actual data in the VarLong. + let read = deser.read::()?; + acc |= (read & 0b01111111) as i64; + + // There are no more bytes. + if (read & 0b10000000) == 0 { + return Ok(VarLong(acc)); + } + + // Make space for the rest of the bits. + acc = acc << 7; + length += 1; + } + + // The VarLong was too long! + Err("VarLong was more than 10 bytes.".to_string()) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + let mut value = self.0; + loop { + let mut temp = (value & 0b01111111) as u8; + value = value >> 7; + if value != 0 { + temp |= 0b10000000; + } + ser.write(temp); + + if value == 0 { + break; + } + } + } +} + +impl PacketData for Vec { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let length: i32 = deser.read::()?.into(); if length < 0 { return Err("String length cannot be negative.".to_string()); } let length = length as usize; if length > MAX_CLIENT_PACKET_SIZE { - return Err("String was too long.".to_string()); + return Err("Byte array was too long.".to_string()); } - let mut buf = Vec::with_capacity(length); - buf.resize(length, 0); - self.read(buf.as_mut_slice())?; - String::from_utf8(buf).map_err(|_| "String was invalid UTF-8.".to_string()) + let mut it = Vec::with_capacity(length); + it.resize(length, 0); + deser.read_exact(it.as_mut_slice())?; + + Ok(it) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(VarInt(self.len() as i32)); + ser.write_exact(self.as_slice()); + } +} + +impl PacketData for String { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let bytes = deser.read()?; + String::from_utf8(bytes).map_err(|_| "String contained invalid UTF-8.".to_string()) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(self.clone().into_bytes()); + } +} + +pub trait PacketJson: DeserializeOwned + Serialize + Sized { } + +impl PacketJson for crate::net::chat::Chat { } + +impl PacketData for S { + fn read(deser: &mut impl PacketDeserializer) -> Result { + let bytes = deser.read::>()?; + serde_json::from_slice(&bytes).map_err(|_| "Bad JSON syntax".to_string()) + } + + fn write(&self, ser: &mut impl PacketSerializer) { + ser.write(serde_json::to_vec(self).unwrap()); + } +} + +pub trait PacketSerializer: Sized { + /// Write a slice of bytes directly, without a length prefix. + fn write_exact(&mut self, value: &[u8]); + + fn write(&mut self, value: D) { + value.write(self) + } +} + +impl PacketSerializer for Vec { + fn write_exact(&mut self, value: &[u8]) { + self.extend_from_slice(value); + } +} + +pub trait PacketDeserializer: Sized { + fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), String>; + fn read_eof(&mut self) -> Result<(), String>; + + fn read(&mut self) -> Result { + D::read(self) } } @@ -124,7 +312,7 @@ impl VecPacketDeserializer<'_> { } impl PacketDeserializer for VecPacketDeserializer<'_> { - fn read(&mut self, buf: &mut [u8]) -> Result<(), String> { + fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), String> { if self.index + buf.len() > self.data.len() { return Err("Tried to read past length of packet.".to_string()); }