Got encryption working! It's ugly, but it works. Cleanup tomorrow.
parent
3701b59a11
commit
60c9a48293
|
@ -1 +1,2 @@
|
|||
/target
|
||||
/private
|
||||
|
|
File diff suppressed because it is too large
Load Diff
15
Cargo.toml
15
Cargo.toml
|
@ -7,13 +7,22 @@ repository = "https://github.com/jamestmartin/tmd"
|
|||
license = "GPL-3.0+"
|
||||
publish = false
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.36"
|
||||
clap = { version = "2.33.1", features = ["yaml"] }
|
||||
flate2 = "1.0.16"
|
||||
serde = { version = "1.0.114", features = ["derive"] }
|
||||
serde_json = "1.0.56"
|
||||
tokio = { version = "0.2.22", features = ["io-util", "macros", "net", "tcp", "rt-threaded"] }
|
||||
uuid = { version = "0.8", features = ["serde"] }
|
||||
|
||||
# Dependencies required for compression
|
||||
flate2 = "1.0.16"
|
||||
|
||||
# Dependencies required for encryption
|
||||
aes = "0.4.0"
|
||||
cfb8 = "0.4.0"
|
||||
num-bigint = "0.3.0"
|
||||
rand = "0.7.3"
|
||||
reqwest = "0.10.7"
|
||||
rsa = "0.3.0"
|
||||
sha-1 = "0.9.1"
|
||||
|
|
66
src/main.rs
66
src/main.rs
|
@ -124,7 +124,71 @@ async fn interact_login(mut con: Connection<Login>) -> io::Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
eprintln!("Client set username to {}.", name);
|
||||
use rand::Rng;
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::{RSAPrivateKey, PaddingScheme};
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
let mut public_key = Vec::new();
|
||||
File::open("private/pub.der").expect("missing public key").read_to_end(&mut public_key)?;
|
||||
|
||||
let mut private_key = Vec::new();
|
||||
File::open("private/priv.der").expect("missing private key").read_to_end(&mut private_key)?;
|
||||
|
||||
let key = RSAPrivateKey::from_pkcs1(&private_key).expect("Invalid private key.");
|
||||
|
||||
let mut verify_token = Vec::new();
|
||||
verify_token.resize(4, 0u8);
|
||||
OsRng.fill(verify_token.as_mut_slice());
|
||||
|
||||
let server_id = "";
|
||||
|
||||
con.write(&Clientbound::EncryptionRequest(EncryptionRequest {
|
||||
server_id: server_id.to_string().into_boxed_str(),
|
||||
public_key: public_key.clone().into_boxed_slice(),
|
||||
verify_token: verify_token.clone().into_boxed_slice(),
|
||||
})).await?;
|
||||
|
||||
let secret = match con.read().await? {
|
||||
Serverbound::EncryptionResponse(encryption_response) => {
|
||||
let token = key.decrypt(PaddingScheme::PKCS1v15Encrypt,
|
||||
&encryption_response.verify_token)
|
||||
.expect("Failed to decrypt verify token.");
|
||||
if token.as_slice() != verify_token.as_slice() {
|
||||
return mk_err("Incorrect verify token.");
|
||||
}
|
||||
|
||||
key.decrypt(PaddingScheme::PKCS1v15Encrypt, &encryption_response.shared_secret)
|
||||
.expect("Failed to decrypt shared secret.")
|
||||
},
|
||||
_ => {
|
||||
return mk_err("Unexpected packet (expected Encryption Response).");
|
||||
}
|
||||
};
|
||||
|
||||
con = con.set_encryption(&secret).expect("Failed to set encryption.");
|
||||
|
||||
use reqwest::Client;
|
||||
|
||||
let server_hash = {
|
||||
let server_hash_bytes = {
|
||||
use sha1::{Sha1, Digest};
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(server_id.as_bytes());
|
||||
hasher.update(&secret);
|
||||
hasher.update(&public_key);
|
||||
hasher.finalize()
|
||||
};
|
||||
|
||||
format!("{:x}", num_bigint::BigInt::from_signed_bytes_be(&server_hash_bytes))
|
||||
};
|
||||
|
||||
// TODO: Authentication, not just encryption.
|
||||
println!("{:?}", Client::new().get("https://sessionserver.mojang.com/session/minecraft/hasJoined")
|
||||
.header("Content-Type", "application/json")
|
||||
.query(&[("username", name.clone()), ("serverId", server_hash.into_boxed_str())])
|
||||
.send().await.expect("Request failed.").text().await.unwrap());
|
||||
|
||||
con.set_compression(Some(64)).await?;
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
mod packet_format;
|
||||
mod stream;
|
||||
|
||||
use crate::net::{Reader, Writer};
|
||||
use crate::net::connection::packet_format::PacketFormat;
|
||||
use crate::net::connection::packet_format::default::DefaultPacketFormat;
|
||||
use crate::net::connection::stream::Stream;
|
||||
use crate::net::protocol::packet_map::PacketMap;
|
||||
use crate::net::protocol::state::ProtocolState;
|
||||
use crate::net::protocol::state::handshake::Handshake;
|
||||
|
@ -11,27 +12,29 @@ use crate::net::protocol::state::play::Play;
|
|||
use crate::net::protocol::state::status::Status;
|
||||
use std::io;
|
||||
use std::marker::PhantomData;
|
||||
use tokio::io::BufStream;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
pub struct Connection<St: ProtocolState> {
|
||||
src: Reader,
|
||||
dest: Writer,
|
||||
rw: Box<dyn Stream>,
|
||||
fmt: Box<dyn PacketFormat>,
|
||||
st: PhantomData<St>,
|
||||
}
|
||||
|
||||
impl<St: ProtocolState> Connection<St> {
|
||||
pub async fn write(&mut self, pkt: &St::Clientbound) -> io::Result<()> {
|
||||
let mut buf = Vec::new();
|
||||
pkt.write(&mut buf);
|
||||
// Turn the packet into bytes.
|
||||
let mut contents = Vec::new();
|
||||
pkt.write(&mut contents);
|
||||
|
||||
self.fmt.send(&mut self.dest, buf.as_ref()).await
|
||||
// Send the packet with the appropriate header.
|
||||
self.fmt.send(&mut self.rw, &contents).await
|
||||
}
|
||||
|
||||
pub async fn read(&mut self) -> io::Result<St::Serverbound> {
|
||||
use crate::net::serialize::VecPacketDeserializer;
|
||||
|
||||
let buf = self.fmt.recieve(&mut self.src).await?;
|
||||
let buf = self.fmt.recieve(&mut self.rw).await?;
|
||||
|
||||
St::Serverbound::read(&mut VecPacketDeserializer::new(buf.as_ref()))
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
|
||||
|
@ -39,8 +42,7 @@ impl<St: ProtocolState> Connection<St> {
|
|||
|
||||
fn into_state<NewSt: ProtocolState>(self) -> Connection<NewSt> {
|
||||
Connection {
|
||||
src: self.src,
|
||||
dest: self.dest,
|
||||
rw: self.rw,
|
||||
fmt: self.fmt,
|
||||
st: PhantomData,
|
||||
}
|
||||
|
@ -53,13 +55,8 @@ impl<St: ProtocolState> Connection<St> {
|
|||
|
||||
impl Connection<Handshake> {
|
||||
pub fn new(stream: TcpStream) -> Self {
|
||||
use tokio::io::{BufReader, BufWriter};
|
||||
|
||||
let (src, dest) = stream.into_split();
|
||||
|
||||
Connection {
|
||||
src: BufReader::new(src),
|
||||
dest: BufWriter::new(dest),
|
||||
rw: Box::new(BufStream::new(stream)),
|
||||
fmt: Box::new(DefaultPacketFormat),
|
||||
st: PhantomData,
|
||||
}
|
||||
|
@ -90,7 +87,7 @@ impl Connection<Login> {
|
|||
// Further packets will use the new compression threshold.
|
||||
match threshold {
|
||||
Some(threshold) => {
|
||||
self.fmt = Box::new(CompressedPacketFormat { threshold: threshold as usize });
|
||||
self.fmt = Box::new(CompressedPacketFormat::new(threshold as usize));
|
||||
},
|
||||
None => {
|
||||
self.fmt = Box::new(DefaultPacketFormat);
|
||||
|
@ -100,6 +97,22 @@ impl Connection<Login> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// WARNING: This function is not idempontent.
|
||||
/// Calling it twice will result in the underlying stream getting encrypted twice.
|
||||
pub fn set_encryption(self, secret: &[u8]) -> Result<Self, String> {
|
||||
use cfb8::Cfb8;
|
||||
use cfb8::stream_cipher::NewStreamCipher;
|
||||
use crate::net::connection::stream::encrypted::EncryptedStream;
|
||||
|
||||
let cipher: Cfb8<aes::Aes128> = Cfb8::new_var(secret, secret).map_err(|err| err.to_string())?;
|
||||
|
||||
Ok(Connection {
|
||||
rw: Box::new(EncryptedStream::new(self.rw, cipher)),
|
||||
fmt: self.fmt,
|
||||
st: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_play(self) -> Connection<Play> {
|
||||
self.into_state()
|
||||
}
|
||||
|
|
|
@ -2,31 +2,33 @@ pub mod compressed;
|
|||
pub mod default;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use crate::net::{Reader, Writer};
|
||||
use std::io;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
|
||||
pub type Reader = dyn AsyncRead + Unpin + Send;
|
||||
pub type Writer = dyn AsyncWrite + Unpin + Send;
|
||||
|
||||
#[async_trait]
|
||||
pub trait PacketFormat: Send + Sync {
|
||||
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()>;
|
||||
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>>;
|
||||
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()>;
|
||||
}
|
||||
|
||||
pub const MAX_CLIENT_PACKET_SIZE: usize = 32767;
|
||||
/// A completely arbitrary limitation on the maximum size of a recieved packet.
|
||||
pub const MAX_PACKET_SIZE: usize = 35565;
|
||||
|
||||
pub async fn read_varint(src: &mut Reader) -> io::Result<(usize, i32)> {
|
||||
let mut length = 1;
|
||||
let mut num_read: usize = 0;
|
||||
let mut acc = 0;
|
||||
while length <= 5 {
|
||||
while num_read < 5 {
|
||||
let byte = src.read_u8().await?;
|
||||
acc |= (byte & 0b01111111) as i32;
|
||||
acc |= ((byte & 0b01111111) as i32) << num_read * 7;
|
||||
|
||||
num_read += 1;
|
||||
|
||||
if byte & 0b10000000 == 0 {
|
||||
return Ok((length, acc));
|
||||
return Ok((num_read, acc));
|
||||
}
|
||||
|
||||
acc = acc << 7;
|
||||
length += 1;
|
||||
}
|
||||
|
||||
Err(io::Error::new(io::ErrorKind::Other, "VarInt was too long.".to_string()))
|
||||
|
|
|
@ -1,69 +1,44 @@
|
|||
use async_trait::async_trait;
|
||||
use crate::net::{Reader, Writer};
|
||||
use crate::net::connection::packet_format::{PacketFormat, MAX_CLIENT_PACKET_SIZE, read_varint};
|
||||
use crate::net::connection::packet_format::
|
||||
{PacketFormat, Reader, Writer, MAX_PACKET_SIZE, read_varint};
|
||||
use std::boxed::Box;
|
||||
use std::io;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
pub struct CompressedPacketFormat {
|
||||
pub threshold: usize,
|
||||
threshold: usize,
|
||||
}
|
||||
|
||||
impl CompressedPacketFormat {
|
||||
pub fn new(threshold: usize) -> Self {
|
||||
Self {
|
||||
threshold: threshold,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A compressed header is in this format:
|
||||
//
|
||||
// packet_length: VarInt
|
||||
// uncompressed_length: VarInt
|
||||
// data: [u8]
|
||||
//
|
||||
// The packet length is the size of the entire packet in bytes,
|
||||
// including the uncompressed length.
|
||||
// The uncompressed length is the size of the uncompressed data in bytes,
|
||||
// or if it is zero, indicates that the data is not compressed.
|
||||
// This is followed by the data, either compressed or uncompressed.
|
||||
|
||||
#[async_trait]
|
||||
impl PacketFormat for CompressedPacketFormat {
|
||||
async fn send(&self, dest: &mut Writer, uncompressed_data: &[u8]) -> io::Result<()> {
|
||||
use crate::net::serialize::{PacketSerializer, VarInt};
|
||||
|
||||
// If the length of the uncompressed packet is less than the threshold,
|
||||
// then we do not compress the packet and set the data_length field to 0.
|
||||
// Otherwise, data_length is set to the length of the uncompressed packet.
|
||||
let will_compress = uncompressed_data.len() >= self.threshold;
|
||||
|
||||
let data_length = if will_compress { uncompressed_data.len() } else { 0 };
|
||||
let mut data_length_buf = Vec::with_capacity(5);
|
||||
data_length_buf.write(VarInt(data_length as i32));
|
||||
|
||||
let mut compression_buf;
|
||||
let data = if will_compress {
|
||||
use flate2::{Compress, Compression, FlushCompress};
|
||||
|
||||
// 1024 is just an arbitrary amount of extra space reserved
|
||||
// in case the output data ends up larger than the input data
|
||||
// (e.g. due to the zlib header).
|
||||
// FIXME: Further research to figure out the exact maximum capacity necessary.
|
||||
// Perhaps you only need space for the header and the data itself can't get bigger?
|
||||
// And what is the limit to how much bigger the data will get?
|
||||
// Currently I don't actually know for a fact that this won't ever drop data.
|
||||
compression_buf = Vec::with_capacity(1024 + uncompressed_data.len());
|
||||
Compress::new(Compression::best(), true)
|
||||
.compress_vec(uncompressed_data, &mut compression_buf, FlushCompress::Finish)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
|
||||
compression_buf.as_slice()
|
||||
} else {
|
||||
uncompressed_data
|
||||
};
|
||||
|
||||
let mut packet_length_buf = Vec::with_capacity(5);
|
||||
packet_length_buf.write(VarInt((data_length_buf.len() + data.len()) as i32));
|
||||
|
||||
{
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
dest.write(packet_length_buf.as_slice()).await?;
|
||||
dest.write(data_length_buf.as_slice()).await?;
|
||||
dest.write(data).await?;
|
||||
dest.flush().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
// First we read in the packet and uncompressed data lengths.
|
||||
let (_, packet_length) = read_varint(src).await?;
|
||||
if packet_length < 0 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Packet length was negative."));
|
||||
}
|
||||
if packet_length > MAX_CLIENT_PACKET_SIZE as i32 {
|
||||
if packet_length > MAX_PACKET_SIZE as i32 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long."));
|
||||
}
|
||||
let packet_length = packet_length as usize;
|
||||
|
@ -72,30 +47,88 @@ impl PacketFormat for CompressedPacketFormat {
|
|||
if data_length < 0 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Data length was negative."));
|
||||
}
|
||||
if data_length > MAX_CLIENT_PACKET_SIZE as i32 {
|
||||
if data_length > MAX_PACKET_SIZE as i32 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Data was too long."));
|
||||
}
|
||||
let data_length = data_length as usize;
|
||||
|
||||
let mut buf = Vec::with_capacity(packet_length);
|
||||
buf.resize(packet_length, 0);
|
||||
src.read_exact(buf.as_mut_slice()).await?;
|
||||
// Now we recieve the remainder of the packet's data.
|
||||
let mut data = Vec::with_capacity(packet_length - data_length_size);
|
||||
data.resize(packet_length, 0);
|
||||
src.read_exact(data.as_mut_slice()).await?;
|
||||
|
||||
let decompressed_buf = if data_length != 0 {
|
||||
use flate2::{Decompress, FlushDecompress};
|
||||
// If the data was not compressed, we simply return it.
|
||||
if data_length == 0 {
|
||||
return Ok(data.into_boxed_slice());
|
||||
}
|
||||
|
||||
let mut decompressed_buf = Vec::with_capacity(data_length);
|
||||
decompressed_buf.resize(data_length, 0);
|
||||
Decompress::new(true)
|
||||
.decompress(&buf, decompressed_buf.as_mut_slice(), FlushDecompress::Finish)
|
||||
// Otherwise, we decompress it.
|
||||
let mut decompressed = Vec::new();
|
||||
decompressed.resize(data_length, 0);
|
||||
|
||||
use flate2::{Decompress, FlushDecompress};
|
||||
Decompress::new(true)
|
||||
.decompress(&data, decompressed.as_mut_slice(), FlushDecompress::Finish)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
|
||||
|
||||
Ok(decompressed.into_boxed_slice())
|
||||
}
|
||||
|
||||
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()> {
|
||||
use crate::net::serialize::{PacketSerializer, VarInt};
|
||||
|
||||
// If the length of the uncompressed data exceeds the threshold,
|
||||
// then we will compress this packet.
|
||||
if data.len() >= self.threshold {
|
||||
// Now we compress the data.
|
||||
use flate2::{Compress, FlushCompress};
|
||||
|
||||
// 1024 is just an arbitrary amount of extra space reserved
|
||||
// in case the output data ends up larger than the input data
|
||||
// (e.g. due to the zlib header).
|
||||
// FIXME: Further research to figure out the exact maximum capacity necessary.
|
||||
// Perhaps you only need space for the header and the data itself can't get bigger?
|
||||
// And what is the limit to how much bigger the data will get?
|
||||
// Currently I don't actually know for a fact that this won't ever drop data.
|
||||
let mut compressed = Vec::with_capacity(1024 + data.len());
|
||||
Compress::new(flate2::Compression::best(), true)
|
||||
.compress_vec(data, &mut compressed, FlushCompress::Finish)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
|
||||
decompressed_buf
|
||||
} else {
|
||||
let mut decompressed_buf = Vec::with_capacity(packet_length - data_length_size);
|
||||
decompressed_buf.copy_from_slice(&buf[data_length_size..]);
|
||||
decompressed_buf
|
||||
};
|
||||
|
||||
Ok(decompressed_buf.into_boxed_slice())
|
||||
// Since the packet is compressed,
|
||||
// data_length will be the length of the uncompressed data.
|
||||
let mut data_length_buf = Vec::with_capacity(5);
|
||||
data_length_buf.write(VarInt(data.len() as i32));
|
||||
|
||||
let mut packet_length_buf = Vec::with_capacity(5);
|
||||
packet_length_buf.write(VarInt((data_length_buf.len() + compressed.len()) as i32));
|
||||
|
||||
{
|
||||
// I have to keep this import in a block so that
|
||||
// it won't conflict with PacketSerialize::write.
|
||||
use tokio::io::AsyncWriteExt;
|
||||
dest.write(packet_length_buf.as_slice()).await?;
|
||||
dest.write(data_length_buf.as_slice()).await?;
|
||||
dest.write(compressed.as_slice()).await?;
|
||||
dest.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
// Since the packet is uncompressed,
|
||||
// the packet length is just the length of the data plus the data_length,
|
||||
// which will just be 0x00 (1 byte long) because the data isn't compressed.
|
||||
let mut packet_length_buf = Vec::with_capacity(5);
|
||||
packet_length_buf.write(VarInt(data.len() as i32 + 1));
|
||||
|
||||
{
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
dest.write(packet_length_buf.as_slice()).await?;
|
||||
dest.write_u8(0x00).await?;
|
||||
dest.write(data).await?;
|
||||
dest.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,29 @@
|
|||
use async_trait::async_trait;
|
||||
use crate::net::{Reader, Writer};
|
||||
use crate::net::connection::packet_format::{PacketFormat, MAX_CLIENT_PACKET_SIZE, read_varint};
|
||||
use crate::net::connection::packet_format::
|
||||
{PacketFormat, Reader, Writer, MAX_PACKET_SIZE, read_varint};
|
||||
use std::boxed::Box;
|
||||
use std::io;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
pub struct DefaultPacketFormat;
|
||||
|
||||
#[async_trait]
|
||||
impl PacketFormat for DefaultPacketFormat {
|
||||
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
let (_, length) = read_varint(src).await?;
|
||||
if length > MAX_PACKET_SIZE as i32 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long.".to_string()));
|
||||
}
|
||||
let length = length as usize;
|
||||
|
||||
let mut buf = Vec::with_capacity(length);
|
||||
buf.resize(length, 0);
|
||||
src.read_exact(buf.as_mut_slice()).await?;
|
||||
|
||||
Ok(buf.into_boxed_slice())
|
||||
}
|
||||
|
||||
async fn send(&self, dest: &mut Writer, data: &[u8]) -> io::Result<()> {
|
||||
use crate::net::serialize::{PacketSerializer, VarInt};
|
||||
|
||||
|
@ -25,18 +40,4 @@ impl PacketFormat for DefaultPacketFormat {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recieve(&self, src: &mut Reader) -> io::Result<Box<[u8]>> {
|
||||
let (_, length) = read_varint(src).await?;
|
||||
if length > MAX_CLIENT_PACKET_SIZE as i32 {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "Packet was too long.".to_string()));
|
||||
}
|
||||
let length = length as usize;
|
||||
|
||||
let mut buf = Vec::with_capacity(length);
|
||||
buf.resize(length, 0);
|
||||
src.read_exact(buf.as_mut_slice()).await?;
|
||||
|
||||
Ok(buf.into_boxed_slice())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
pub mod encrypted;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin { }
|
||||
impl<S: AsyncRead + AsyncWrite + Send + Unpin> Stream for S { }
|
|
@ -0,0 +1,88 @@
|
|||
use aes::Aes128;
|
||||
use cfb8::Cfb8;
|
||||
use cfb8::stream_cipher::StreamCipher;
|
||||
use crate::net::connection::stream::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
use std::task::Context;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, Result};
|
||||
|
||||
pub struct EncryptedStream {
|
||||
rw: Box<dyn Stream>,
|
||||
cipher: Cfb8<Aes128>,
|
||||
write_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptedStream {
|
||||
pub fn new(rw: Box<dyn Stream>, cipher: Cfb8<Aes128>) -> Self {
|
||||
Self {
|
||||
rw: rw,
|
||||
cipher: cipher,
|
||||
write_buf: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self, cx: &mut Context) -> Poll<Result<()>> {
|
||||
// We don't know when the internal writer will be ready,
|
||||
// so we have to coax it into returning "pending" and scheduling an interrupt for us.
|
||||
// Either that, or we finish writing our buffer and flush.
|
||||
loop {
|
||||
if self.write_buf.len() == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.rw).poll_write(cx, &self.write_buf) {
|
||||
Poll::Ready(Ok(length)) => {
|
||||
let mut new_buf = Vec::new();
|
||||
new_buf.copy_from_slice(&self.write_buf[length..]);
|
||||
self.write_buf = new_buf;
|
||||
},
|
||||
other => return other.map(|x| x.map(|_| ()))
|
||||
}
|
||||
}
|
||||
|
||||
Pin::new(&mut self.rw).poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for EncryptedStream {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize>> {
|
||||
let me = Pin::into_inner(self);
|
||||
|
||||
match Pin::new(&mut me.rw).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(bytes)) => {
|
||||
me.cipher.decrypt(&mut buf[..bytes]);
|
||||
Poll::Ready(Ok(bytes))
|
||||
},
|
||||
other => other
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for EncryptedStream {
|
||||
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let me = Pin::into_inner(self);
|
||||
|
||||
let index = me.write_buf.len();
|
||||
// Copy data to our write buffer and then encrypt it.
|
||||
me.write_buf.extend_from_slice(buf);
|
||||
me.cipher.encrypt(&mut me.write_buf[index..]);
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
|
||||
Pin::into_inner(self).flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
|
||||
let me = Pin::into_inner(self);
|
||||
|
||||
match me.flush(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
Pin::new(&mut me.rw).poll_shutdown(cx)
|
||||
},
|
||||
other => other
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,9 +2,3 @@ pub mod chat;
|
|||
pub mod connection;
|
||||
pub mod protocol;
|
||||
pub mod serialize;
|
||||
|
||||
use tokio::io::{BufReader, BufWriter};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
|
||||
pub type Reader = BufReader<OwnedReadHalf>;
|
||||
pub type Writer = BufWriter<OwnedWriteHalf>;
|
||||
|
|
|
@ -142,13 +142,15 @@ macro_rules! impl_varnum {
|
|||
|
||||
impl PacketReadable for $name {
|
||||
fn read(deser: &mut impl PacketDeserializer) -> Result<Self, String> {
|
||||
let mut length = 1;
|
||||
let mut num_read: usize = 0;
|
||||
let mut acc = 0;
|
||||
while length <= $length {
|
||||
while num_read < $length {
|
||||
// If the highest bit is set, there are further bytes to be read;
|
||||
// the rest of the bits are the actual bits of the number.
|
||||
let read = deser.read::<u8>()?;
|
||||
acc |= (read & 0b01111111) as $wraps;
|
||||
acc |= ((read & 0b01111111) as $wraps) << num_read * 7;
|
||||
|
||||
num_read += 1;
|
||||
|
||||
if (read & 0b10000000) == 0 {
|
||||
// There are no more bytes.
|
||||
|
@ -157,7 +159,6 @@ macro_rules! impl_varnum {
|
|||
|
||||
// Make space for the rest of the bits.
|
||||
acc = acc << 7;
|
||||
length += 1;
|
||||
}
|
||||
|
||||
Err(format!("VarNum was more than {} bytes.", $length))
|
||||
|
|
Loading…
Reference in New Issue