Begin work on improving network packet abstractions.
* Packet header/stream stuff is now defined by a PacketFormat. * Actual packet serialization/deserialization is handled by PacketSerializer/PacketDeserializer. * The end API is still awkaward, so more work is needed.master
parent
9ee0dbe63e
commit
d32118db4f
|
@ -9,6 +9,17 @@ dependencies = [
|
|||
"winapi 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a265e3abeffdce30b2e26b7a11b222fe37c6067404001b434101457d0385eb92"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
|
@ -291,6 +302,7 @@ dependencies = [
|
|||
name = "tmd"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"clap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
@ -10,6 +10,7 @@ publish = false
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.36"
|
||||
clap = { version = "2.33.1", features = ["yaml"] }
|
||||
serde = { version = "1.0.114", features = ["derive"] }
|
||||
serde_json = "1.0.56"
|
||||
|
|
93
src/main.rs
93
src/main.rs
|
@ -1,16 +1,14 @@
|
|||
mod net;
|
||||
|
||||
use crate::net::{Reader, Writer};
|
||||
use crate::net::chat::Chat;
|
||||
use crate::net::source;
|
||||
use crate::net::source::{PacketError, PacketSource};
|
||||
use tokio::io::BufWriter;
|
||||
use crate::net::format::{PacketFormat, DefaultPacketFormat};
|
||||
use crate::net::serialize::VecPacketDeserializer;
|
||||
use tokio::io::{BufReader, BufWriter};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::net::tcp::WriteHalf;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use PacketError::*;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
let yaml = clap::load_yaml!("cli.yml");
|
||||
|
@ -43,8 +41,8 @@ async fn listen(mut listener: TcpListener) {
|
|||
}
|
||||
|
||||
async fn accept_connection(mut socket: TcpStream) {
|
||||
let (mut read, write) = socket.split();
|
||||
let mut source = PacketSource::new(&mut read);
|
||||
let (read, write) = socket.split();
|
||||
let mut source = BufReader::new(read);
|
||||
let mut dest = BufWriter::new(write);
|
||||
|
||||
eprintln!("Client connected.");
|
||||
|
@ -54,53 +52,66 @@ async fn accept_connection(mut socket: TcpStream) {
|
|||
}
|
||||
}
|
||||
|
||||
async fn interact_handshake(source: &mut PacketSource<'_>, dest: &mut BufWriter<WriteHalf<'_>>) -> source::Result<()> {
|
||||
async fn interact_handshake<'a>(source: &mut Reader<'a>, dest: &mut Writer<'a>) -> io::Result<()> {
|
||||
use crate::net::packet::handshake::*;
|
||||
use PacketHandshakeServerbound::*;
|
||||
|
||||
match read_packet_handshake(source).await? {
|
||||
Handshake(pkt) => {
|
||||
if pkt.next_state == HandshakeNextState::Status {
|
||||
interact_status(source, dest).await
|
||||
} else {
|
||||
Err(PktError("We do not support client log-in yet.".to_string()))
|
||||
let (id, data) = DefaultPacketFormat.recieve(source).await?;
|
||||
let mut deser = VecPacketDeserializer::new(&data);
|
||||
|
||||
match read_packet_handshake(id, &mut deser)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err)) {
|
||||
Ok(pkt) => match pkt {
|
||||
Handshake(pkt) => {
|
||||
if pkt.next_state == HandshakeNextState::Status {
|
||||
interact_status(source, dest).await
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "We do not support client log-in yet.".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(err) => Err(io::Error::new(io::ErrorKind::Other, err))
|
||||
}
|
||||
}
|
||||
|
||||
async fn interact_status(source: &mut PacketSource<'_>, dest: &mut BufWriter<WriteHalf<'_>>) -> source::Result<()> {
|
||||
async fn interact_status<'a>(source: &mut Reader<'a>, dest: &mut Writer<'a>) -> io::Result<()> {
|
||||
use crate::net::packet::status::*;
|
||||
use PacketStatusClientbound::*;
|
||||
use PacketStatusServerbound::*;
|
||||
|
||||
loop {
|
||||
match read_packet_status(source).await? {
|
||||
Request => {
|
||||
match write_packet_status(dest, Response(PacketResponse {
|
||||
version: PacketResponseVersion {
|
||||
name: "1.16.1".to_string(),
|
||||
protocol: 736,
|
||||
},
|
||||
players: PacketResponsePlayers {
|
||||
max: 255,
|
||||
online: 0,
|
||||
sample: Vec::new(),
|
||||
},
|
||||
description: Chat { text: "Hello, world!".to_string() },
|
||||
favicon: None,
|
||||
})).await {
|
||||
Ok(_) => {},
|
||||
Err(err) => return Err(IoError(err))
|
||||
let (id, data) = DefaultPacketFormat.recieve(source).await?;
|
||||
let mut deser = VecPacketDeserializer::new(&data);
|
||||
|
||||
match read_packet_status(id, &mut deser)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err)) {
|
||||
Ok(pkt) => match pkt {
|
||||
Request => {
|
||||
let mut buf = Vec::new();
|
||||
let id = write_packet_status(&mut buf, Response(PacketResponse {
|
||||
version: PacketResponseVersion {
|
||||
name: "1.16.1".to_string(),
|
||||
protocol: 736,
|
||||
},
|
||||
players: PacketResponsePlayers {
|
||||
max: 255,
|
||||
online: 0,
|
||||
sample: Vec::new(),
|
||||
},
|
||||
description: Chat { text: "Hello, world!".to_string() },
|
||||
favicon: None,
|
||||
}));
|
||||
DefaultPacketFormat.send(dest, id, buf.as_slice()).await?;
|
||||
},
|
||||
Ping(payload) => {
|
||||
let mut buf = Vec::new();
|
||||
let id = write_packet_status(&mut buf, Pong(payload));
|
||||
DefaultPacketFormat.send(dest, id, buf.as_slice()).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
Ping(payload) => {
|
||||
match write_packet_status(dest, Pong(payload)).await {
|
||||
Ok(_) => {},
|
||||
Err(err) => return Err(IoError(err))
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
use async_trait::async_trait;
|
||||
use crate::net::{Reader, Writer};
|
||||
use std::boxed::Box;
|
||||
use std::io;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[async_trait]
|
||||
pub trait PacketFormat {
|
||||
async fn send<'wr, 'w>(&self, dest: &'wr mut Writer<'w>, packet_id: i32, data: &[u8]) -> io::Result<()>;
|
||||
async fn recieve<'rr, 'r>(&self, src: &'rr mut Reader<'r>) -> io::Result<(i32, Box<[u8]>)>;
|
||||
}
|
||||
|
||||
pub const MAX_CLIENT_PACKET_SIZE: usize = 32767;
|
||||
|
||||
pub struct DefaultPacketFormat;
|
||||
|
||||
async fn read_varint<'rr, 'r>(src: &'rr mut Reader<'r>) -> io::Result<(usize, i32)> {
|
||||
let mut length = 1;
|
||||
let mut acc = 0;
|
||||
while length <= 5 {
|
||||
let byte = src.read_u8().await?;
|
||||
acc |= (byte & 0b01111111) as i32;
|
||||
|
||||
if byte & 0b10000000 == 0 {
|
||||
return Ok((length, acc));
|
||||
}
|
||||
|
||||
acc = acc << 7;
|
||||
length += 1;
|
||||
}
|
||||
|
||||
Err(io::Error::new(io::ErrorKind::Other, "VarInt was too long."))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PacketFormat for DefaultPacketFormat {
|
||||
async fn send<'wr, 'w>(&self, dest: &'wr mut Writer<'w>, packet_id: i32, data: &[u8]) -> io::Result<()> {
|
||||
use crate::net::serialize::PacketSerializer;
|
||||
|
||||
let mut packet_id_buf = Vec::with_capacity(5);
|
||||
packet_id_buf.write_varint(packet_id);
|
||||
|
||||
let packet_length = packet_id_buf.len() + data.len();
|
||||
let mut packet_length_buf = Vec::with_capacity(5);
|
||||
packet_length_buf.write_varint(packet_length as i32);
|
||||
|
||||
dest.write(packet_length_buf.as_slice()).await?;
|
||||
dest.write(packet_id_buf.as_slice()).await?;
|
||||
dest.write(data).await?;
|
||||
dest.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recieve<'rr, 'r>(&self, src: &'rr mut Reader<'r>) -> io::Result<(i32, Box<[u8]>)> {
|
||||
let (_, length) = read_varint(src).await?;
|
||||
if length > MAX_CLIENT_PACKET_SIZE as i32 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long."));
|
||||
}
|
||||
let length = length as usize;
|
||||
|
||||
let (id_length, id) = read_varint(src).await?;
|
||||
|
||||
let mut buf = Vec::with_capacity(length - id_length);
|
||||
buf.resize(length - id_length, 0);
|
||||
src.read_exact(buf.as_mut_slice()).await?;
|
||||
|
||||
Ok((id, buf.into_boxed_slice()))
|
||||
}
|
||||
}
|
|
@ -1,3 +1,10 @@
|
|||
pub mod chat;
|
||||
pub mod format;
|
||||
pub mod packet;
|
||||
pub mod source;
|
||||
pub mod serialize;
|
||||
|
||||
use tokio::io::{BufReader, BufWriter};
|
||||
use tokio::net::tcp::{ReadHalf, WriteHalf};
|
||||
|
||||
pub type Reader<'a> = BufReader<ReadHalf<'a>>;
|
||||
pub type Writer<'a> = BufWriter<WriteHalf<'a>>;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::net::source::{PacketError, PacketSource, Result};
|
||||
use PacketError::PktError;
|
||||
use crate::net::serialize::PacketDeserializer;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum HandshakeNextState {
|
||||
|
@ -20,21 +19,21 @@ pub enum PacketHandshakeServerbound {
|
|||
Handshake(PacketHandshake),
|
||||
}
|
||||
|
||||
pub async fn read_packet_handshake(source: &mut PacketSource<'_>) -> Result<PacketHandshakeServerbound> {
|
||||
pub fn read_packet_handshake(id: i32, deser: &mut impl PacketDeserializer)
|
||||
-> Result<PacketHandshakeServerbound, String> {
|
||||
use PacketHandshakeServerbound::*;
|
||||
|
||||
let _length = source.read_varint().await?;
|
||||
let id = source.read_varint().await?;
|
||||
match id {
|
||||
0x00 => {
|
||||
let protocol_version = source.read_varint().await?;
|
||||
let server_address = source.read_string().await?;
|
||||
let server_port = source.read_u16().await?;
|
||||
let next_state = match source.read_varint().await? {
|
||||
let protocol_version = deser.read_varint()?;
|
||||
let server_address = deser.read_string()?;
|
||||
let server_port = deser.read_u16()?;
|
||||
let next_state = match deser.read_varint()? {
|
||||
1 => HandshakeNextState::Status,
|
||||
2 => HandshakeNextState::Login,
|
||||
n => return Err(PktError(format!("Invalid next protocol state in handshake: {}", n)))
|
||||
n => return Err(format!("Invalid next protocol state in handshake: {}", n))
|
||||
};
|
||||
deser.read_eof()?;
|
||||
Ok(Handshake(PacketHandshake {
|
||||
protocol_version: protocol_version,
|
||||
server_address: server_address,
|
||||
|
@ -42,6 +41,6 @@ pub async fn read_packet_handshake(source: &mut PacketSource<'_>) -> Result<Pack
|
|||
next_state: next_state,
|
||||
}))
|
||||
},
|
||||
id => Err(PktError(format!("Invalid handshake packet id: {}", id)))
|
||||
id => Err(format!("Invalid handshake packet id: {}", id))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,9 @@
|
|||
use crate::net::chat::Chat;
|
||||
use crate::net::source::{PacketError, PacketSource, Result};
|
||||
use crate::net::serialize::{PacketSerializer, PacketDeserializer};
|
||||
use serde::Serialize;
|
||||
use std::convert::TryInto;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufWriter;
|
||||
use tokio::net::tcp::WriteHalf;
|
||||
use uuid::Uuid;
|
||||
|
||||
use PacketError::PktError;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PacketResponseVersion {
|
||||
pub name: String,
|
||||
|
@ -48,70 +43,37 @@ pub enum PacketStatusServerbound {
|
|||
Ping([u8; 8]),
|
||||
}
|
||||
|
||||
pub async fn read_packet_status(source: &mut PacketSource<'_>) -> Result<PacketStatusServerbound> {
|
||||
pub fn read_packet_status(id: i32, deser: &mut impl PacketDeserializer)
|
||||
-> Result<PacketStatusServerbound, String> {
|
||||
use PacketStatusServerbound::*;
|
||||
|
||||
let _length = source.read_varint().await?;
|
||||
let id = source.read_varint().await?;
|
||||
|
||||
match id {
|
||||
0 => Ok(Request),
|
||||
0 => {
|
||||
deser.read_eof()?;
|
||||
Ok(Request)
|
||||
},
|
||||
1 => {
|
||||
let mut buf = [0; 8];
|
||||
source.read_exact(&mut buf).await?;
|
||||
deser.read(&mut buf)?;
|
||||
deser.read_eof()?;
|
||||
Ok(Ping(buf.try_into().unwrap()))
|
||||
}
|
||||
id => Err(PktError(format!("Invalid status packet id: {}", id)))
|
||||
id => Err(format!("Invalid status packet id: {}", id))
|
||||
}
|
||||
}
|
||||
|
||||
fn write_varint(dest: &mut Vec<u8>, mut value: i32) {
|
||||
loop {
|
||||
let mut temp = (value & 0b01111111) as u8;
|
||||
value = value >> 7;
|
||||
if value != 0 {
|
||||
temp |= 0b10000000;
|
||||
}
|
||||
dest.push(temp);
|
||||
|
||||
if value == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_slice(dest: &mut Vec<u8>, bytes: &[u8]) {
|
||||
write_varint(dest, bytes.len() as i32);
|
||||
dest.extend_from_slice(bytes);
|
||||
}
|
||||
|
||||
fn write_string(dest: &mut Vec<u8>, value: &str) {
|
||||
write_slice(dest, value.as_bytes());
|
||||
}
|
||||
|
||||
pub async fn write_packet_status(dest: &mut BufWriter<WriteHalf<'_>>,
|
||||
packet: PacketStatusClientbound) -> std::io::Result<()> {
|
||||
pub fn write_packet_status(ser: &mut impl PacketSerializer, packet: PacketStatusClientbound)
|
||||
-> i32 {
|
||||
use PacketStatusClientbound::*;
|
||||
|
||||
let mut data = Vec::new();
|
||||
|
||||
match packet {
|
||||
Response(response) => {
|
||||
write_varint(&mut data, 0x00);
|
||||
write_slice(&mut data, serde_json::to_vec(&response).unwrap().as_slice());
|
||||
ser.write_json(&response);
|
||||
0x00
|
||||
},
|
||||
Pong(payload) => {
|
||||
write_varint(&mut data, 0x01);
|
||||
data.extend_from_slice(&payload);
|
||||
ser.write(&payload);
|
||||
0x01
|
||||
}
|
||||
}
|
||||
|
||||
let mut packet_length_buf = Vec::new();
|
||||
write_varint(&mut packet_length_buf, data.len() as i32);
|
||||
|
||||
dest.write_all(packet_length_buf.as_slice()).await?;
|
||||
dest.write_all(data.as_slice()).await?;
|
||||
dest.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
use crate::net::format::MAX_CLIENT_PACKET_SIZE;
|
||||
use serde::Serialize;
|
||||
|
||||
pub trait PacketSerializer {
|
||||
/// Write a slice of bytes directly, without a length prefix.
|
||||
fn write(&mut self, value: &[u8]);
|
||||
|
||||
fn write_u8(&mut self, value: u8) {
|
||||
self.write(&value.to_le_bytes());
|
||||
}
|
||||
|
||||
/// https://wiki.vg/Protocol#VarInt_and_VarLong
|
||||
fn write_varint(&mut self, mut value: i32) {
|
||||
loop {
|
||||
let mut temp = (value & 0b01111111) as u8;
|
||||
value = value >> 7;
|
||||
if value != 0 {
|
||||
temp |= 0b10000000;
|
||||
}
|
||||
self.write_u8(temp);
|
||||
|
||||
if value == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a varint-length-prefixed byte slice.
|
||||
fn write_slice(&mut self, value: &[u8]) {
|
||||
self.write_varint(value.len() as i32);
|
||||
self.write(value);
|
||||
}
|
||||
|
||||
/// Write a varint-length-prefixed string.
|
||||
fn write_str(&mut self, value: &str) {
|
||||
self.write_slice(value.as_bytes());
|
||||
}
|
||||
|
||||
fn write_json(&mut self, value: &impl Serialize) {
|
||||
self.write_slice(serde_json::to_vec(value).unwrap().as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketSerializer for Vec<u8> {
|
||||
fn write(&mut self, value: &[u8]) {
|
||||
self.extend_from_slice(value);
|
||||
}
|
||||
|
||||
fn write_u8(&mut self, value: u8) {
|
||||
self.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PacketDeserializer {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<(), String>;
|
||||
fn read_eof(&mut self) -> Result<(), String>;
|
||||
|
||||
fn read_u8(&mut self) -> Result<u8, String> {
|
||||
let mut buf = [0; 1];
|
||||
self.read(&mut buf)?;
|
||||
Ok(u8::from_le_bytes(buf))
|
||||
}
|
||||
|
||||
fn read_u16(&mut self) -> Result<u16, String> {
|
||||
let mut buf = [0; 2];
|
||||
self.read(&mut buf)?;
|
||||
Ok(u16::from_le_bytes(buf))
|
||||
}
|
||||
|
||||
fn read_varint(&mut self) -> Result<i32, String> {
|
||||
let mut length = 1;
|
||||
let mut acc = 0;
|
||||
// VarInts must not be longer than 5 bytes.
|
||||
while length <= 5 {
|
||||
// If the highest bit is set, there are further bytes to be read from this VarInt;
|
||||
// the rest of the bits are the actual data in the VarInt.
|
||||
let read = self.read_u8()?;
|
||||
acc |= (read & 0b01111111) as i32;
|
||||
|
||||
// There are no mo
|
||||
if (read & 0b10000000) == 0 {
|
||||
return Ok(acc);
|
||||
}
|
||||
|
||||
// Make space for the rest of the bits.
|
||||
acc = acc << 7;
|
||||
length += 1;
|
||||
}
|
||||
|
||||
// The VarInt was too long!
|
||||
Err("VarInt was more than 5 bytes.".to_string())
|
||||
}
|
||||
|
||||
fn read_string(&mut self) -> Result<String, String> {
|
||||
let length = self.read_varint()?;
|
||||
if length < 0 {
|
||||
return Err("String length cannot be negative.".to_string());
|
||||
}
|
||||
|
||||
let length = length as usize;
|
||||
if length > MAX_CLIENT_PACKET_SIZE {
|
||||
return Err("String was too long.".to_string());
|
||||
}
|
||||
|
||||
let mut buf = Vec::with_capacity(length);
|
||||
buf.resize(length, 0);
|
||||
self.read(buf.as_mut_slice())?;
|
||||
String::from_utf8(buf).map_err(|_| "String was invalid UTF-8.".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VecPacketDeserializer<'a> {
|
||||
data: &'a [u8],
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl VecPacketDeserializer<'_> {
|
||||
pub fn new<'a>(data: &'a [u8]) -> VecPacketDeserializer<'a> {
|
||||
VecPacketDeserializer {
|
||||
data: data,
|
||||
index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketDeserializer for VecPacketDeserializer<'_> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<(), String> {
|
||||
if self.index + buf.len() > self.data.len() {
|
||||
return Err("Tried to read past length of packet.".to_string());
|
||||
}
|
||||
|
||||
let len = buf.len();
|
||||
buf[..].copy_from_slice(&self.data[self.index..self.index + len]);
|
||||
self.index += buf.len();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_eof(&mut self) -> Result<(), String> {
|
||||
if self.index != self.data.len() {
|
||||
return Err("Packet contained more data than necessary.".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,132 +0,0 @@
|
|||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::tcp::ReadHalf;
|
||||
|
||||
const MAX_CLIENT_PACKET_SIZE: usize = 32767;
|
||||
|
||||
pub struct PacketSource<'a> {
|
||||
source: &'a mut ReadHalf<'a>,
|
||||
buf: [u8; MAX_CLIENT_PACKET_SIZE],
|
||||
index: usize,
|
||||
used: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PacketError {
|
||||
IoError(std::io::Error),
|
||||
PktError(String),
|
||||
}
|
||||
|
||||
use PacketError::*;
|
||||
|
||||
pub type Result<A> = std::result::Result<A, PacketError>;
|
||||
|
||||
impl PacketSource<'_> {
|
||||
pub fn new<'a>(source: &'a mut ReadHalf<'a>) -> PacketSource<'a> {
|
||||
PacketSource {
|
||||
source: source,
|
||||
buf: [0; MAX_CLIENT_PACKET_SIZE],
|
||||
index: 0,
|
||||
used: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Come up with a more efficient way of buffering.
|
||||
pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
|
||||
let mut index = 0;
|
||||
while buf.len() > index {
|
||||
let bytes_needed = buf.len() - index;
|
||||
// If we've already read as many bytes as we need, return.
|
||||
if bytes_needed == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let bytes_remaining = self.used - self.index;
|
||||
// If we're out of bytes to read, read some more.
|
||||
if bytes_remaining == 0 {
|
||||
// If we've already used up the entire buffer, restart from the beginning.
|
||||
let space_remaining = self.buf.len() - self.used;
|
||||
if space_remaining < 1 {
|
||||
self.used = 0;
|
||||
self.index = 0;
|
||||
}
|
||||
|
||||
let len = self.buf.len();
|
||||
let read = match self.source.read(&mut self.buf[self.index..len]).await {
|
||||
Ok(read) => read,
|
||||
Err(err) => return Err(IoError(err))
|
||||
};
|
||||
|
||||
self.used += read;
|
||||
continue;
|
||||
}
|
||||
|
||||
let bytes_to_copy = bytes_remaining.min(bytes_needed);
|
||||
|
||||
buf[index..index + bytes_to_copy]
|
||||
.copy_from_slice(&self.buf[self.index..self.index + bytes_to_copy]);
|
||||
index += bytes_to_copy;
|
||||
self.index += bytes_to_copy;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read_u8(&mut self) -> Result<u8> {
|
||||
let mut buf = [0; 1];
|
||||
self.read_exact(&mut buf).await?;
|
||||
Ok(buf[0])
|
||||
}
|
||||
|
||||
pub async fn read_u16(&mut self) -> Result<u16> {
|
||||
let mut buf = [0; 2];
|
||||
self.read_exact(&mut buf).await?;
|
||||
Ok(u16::from_le_bytes(buf))
|
||||
}
|
||||
|
||||
pub async fn read_i64(&mut self) -> Result<i64> {
|
||||
let mut buf = [0; 8];
|
||||
self.read_exact(&mut buf).await?;
|
||||
Ok(i64::from_le_bytes(buf))
|
||||
}
|
||||
|
||||
pub async fn read_varint(&mut self) -> Result<i32> {
|
||||
let mut length = 1;
|
||||
let mut acc = 0;
|
||||
// VarInts must not be longer than 5 bytes.
|
||||
while length <= 5 {
|
||||
// If the highest bit is set, there are further bytes to be read from this VarInt;
|
||||
// the rest of the bits are the actual data in the VarInt.
|
||||
let read = self.read_u8().await?;
|
||||
acc |= (read & 0b01111111) as i32;
|
||||
|
||||
// There are no mo
|
||||
if (read & 0b10000000) == 0 {
|
||||
return Ok(acc);
|
||||
}
|
||||
|
||||
// Make space for the rest of the bits.
|
||||
acc = acc << 7;
|
||||
length += 1;
|
||||
}
|
||||
|
||||
// The VarInt was too long!
|
||||
Err(PktError("VarInt was more than 5 bytes.".to_string()))
|
||||
}
|
||||
|
||||
pub async fn read_string(&mut self) -> Result<String> {
|
||||
let length = self.read_varint().await?;
|
||||
if length < 0 {
|
||||
return Err(PktError("String length cannot be negative.".to_string()));
|
||||
}
|
||||
|
||||
let length = length as usize;
|
||||
if length > MAX_CLIENT_PACKET_SIZE {
|
||||
return Err(PktError("String was too long.".to_string()));
|
||||
}
|
||||
|
||||
let mut buf = Vec::new();
|
||||
buf.resize(length, 0);
|
||||
self.read_exact(buf.as_mut_slice()).await?;
|
||||
String::from_utf8(buf).map_err(|_| PktError("String was invalid UTF-8.".to_string()))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue