use std::convert::TryFrom;
use thiserror::Error;
#[derive(Debug, Error)]
enum Error {
#[error("์๋ชป๋ varint")]
InvalidVarint,
#[error("์๋ชป๋ wire-type")]
InvalidWireType,
#[error("์์์น ๋ชปํ EOF")]
UnexpectedEOF,
#[error("์๋ชป๋ ๊ธธ์ด")]
InvalidSize(#[from] std::num::TryFromIntError),
#[error("์์์น ๋ชปํ wire-type")]
UnexpectedWireType,
#[error("์๋ชป๋ ๋ฌธ์์ด(UTF-8 ์๋)")]
InvalidString,
}
/// The wire type encoded in the low three bits of a field tag.
///
/// `I64` (wire type 1) is intentionally not modelled: it is not needed for
/// this exercise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WireType {
    /// The value is a single VARINT.
    Varint,
    /// The value is a length, encoded as a VARINT, followed by exactly that
    /// number of bytes.
    Len,
    /// The value is exactly 4 bytes in little-endian order containing a
    /// 32-bit signed integer.
    I32,
}
/// A field value, typed according to the wire type it was encoded with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldValue<'a> {
    /// The decoded value of a VARINT.
    Varint(u64),
    // I64(i64) is intentionally omitted: it is not needed for this exercise.
    /// The raw payload of a length-delimited value, borrowed from the input.
    Len(&'a [u8]),
    /// A 4-byte little-endian value decoded as a signed 32-bit integer.
    I32(i32),
}
#[derive(Debug)]
/// ํ๋ ๋ฒํธ ๋ฐ ๊ฐ์ ํฌํจํ๋ ํ๋์
๋๋ค.
struct Field<'a> {
field_num: u64,
value: FieldValue<'a>,
}
/// A message type that can be built up one decoded field at a time.
///
/// `parse_message` starts from `Self::default()` and hands every decoded
/// field to [`ProtoMessage::add_field`]. `'a` is the lifetime of the input
/// buffer that borrowed field data points into.
trait ProtoMessage<'a>
where
    Self: Default + 'a,
{
    /// Incorporates one decoded `field` into `self`.
    fn add_field(&mut self, field: Field<'a>) -> Result<(), Error>;
}
impl TryFrom<u64> for WireType {
type Error = Error;
fn try_from(value: u64) -> Result<WireType, Error> {
Ok(match value {
0 => WireType::Varint,
//1 => WireType::I64, -- ์ด ์ฐ์ต์์๋ ํ์ํ์ง ์์ต๋๋ค.
2 => WireType::Len,
5 => WireType::I32,
_ => return Err(Error::InvalidWireType),
})
}
}
impl<'a> FieldValue<'a> {
fn as_string(&self) -> Result<&'a str, Error> {
let FieldValue::Len(data) = self else {
return Err(Error::UnexpectedWireType);
};
std::str::from_utf8(data).map_err(|_| Error::InvalidString)
}
fn as_bytes(&self) -> Result<&'a [u8], Error> {
let FieldValue::Len(data) = self else {
return Err(Error::UnexpectedWireType);
};
Ok(data)
}
fn as_u64(&self) -> Result<u64, Error> {
let FieldValue::Varint(value) = self else {
return Err(Error::UnexpectedWireType);
};
Ok(*value)
}
}
/// Parses a base-128 VARINT from the front of `data`.
///
/// Returns the decoded value together with the bytes that follow it.
///
/// # Errors
///
/// Returns [`Error::InvalidVarint`] in three cases:
/// - `data` ends before a terminating byte (one with the high bit clear) is found;
/// - the encoding is longer than the 10 bytes a `u64` can need;
/// - the 10th byte carries bits that would lie beyond bit 63.
fn parse_varint(data: &[u8]) -> Result<(u64, &[u8]), Error> {
    // Each byte carries 7 payload bits, so 64 bits need at most ceil(64 / 7) = 10 bytes.
    const MAX_VARINT_LEN: usize = 10;

    let mut value = 0u64;
    for (i, &byte) in data.iter().take(MAX_VARINT_LEN).enumerate() {
        let payload = u64::from(byte & 0x7f);
        // The 10th byte lands at bit 63, so only its lowest payload bit fits.
        if i == MAX_VARINT_LEN - 1 && payload > 1 {
            return Err(Error::InvalidVarint);
        }
        // Groups are little-endian: the first byte holds the least significant bits.
        value |= payload << (7 * i);
        if byte & 0x80 == 0 {
            // High bit clear: this was the last byte of the VARINT.
            return Ok((value, &data[i + 1..]));
        }
    }
    // Ran out of input, or still no terminating byte after 10 bytes.
    Err(Error::InvalidVarint)
}
/// Splits a field tag into its field number and its wire type.
///
/// The low three bits of `tag` select the wire type. The remaining high bits
/// are the field number. An unrecognised wire-type code is reported by
/// `WireType::try_from`.
fn unpack_tag(tag: u64) -> Result<(u64, WireType), Error> {
    let wire_type = WireType::try_from(tag & 0b111)?;
    Ok((tag >> 3, wire_type))
}
/// ํ๋๋ฅผ ํ์ฑํ์ฌ ๋๋จธ์ง ๋ฐ์ดํธ๋ฅผ ๋ฐํํฉ๋๋ค.
fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> {
let (tag, remainder) = parse_varint(data)?;
let (field_num, wire_type) = unpack_tag(tag)?;
let (fieldvalue, remainder) = match wire_type {
WireType::Varint => {
let (value, remainder) = parse_varint(remainder)?;
(FieldValue::Varint(value), remainder)
}
WireType::Len => {
let (len, remainder) = parse_varint(remainder)?;
let len: usize = len.try_into()?;
if remainder.len() < len {
return Err(Error::UnexpectedEOF);
}
let (value, remainder) = remainder.split_at(len);
(FieldValue::Len(value), remainder)
}
WireType::I32 => {
if remainder.len() < 4 {
return Err(Error::UnexpectedEOF);
}
let (value, remainder) = remainder.split_at(4);
// `value`์ ๊ธธ์ด๊ฐ 4๋ฐ์ดํธ์ด๋ฏ๋ก ์ค๋ฅ๋ฅผ ๋ํ ํด์ ํฉ๋๋ค.
let value = i32::from_le_bytes(value.try_into().unwrap());
(FieldValue::I32(value), remainder)
}
};
Ok((Field { field_num, value: fieldvalue }, remainder))
}
/// ์ฃผ์ด์ง ๋ฐ์ดํฐ์์ ๋ฉ์์ง๋ฅผ ํ์ฑํ์ฌ ๋ฉ์์ง์ ๊ฐ ํ๋์ ๋ํด `T::add_field`๋ฅผ ํธ์ถํฉ๋๋ค.
///
/// ์ ์ฒด ์
๋ ฅ์ด ์ฌ์ฉ๋ฉ๋๋ค.
fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
let mut result = T::default();
while !data.is_empty() {
let parsed = parse_field(data)?;
result.add_field(parsed.0)?;
data = parsed.1;
}
Ok(result)
}
/// One phone-number entry: the sub-message stored in field 3 of a `Person`.
///
/// Both strings borrow from the encoded input.
#[derive(Default, Debug)]
struct PhoneNumber<'a> {
    /// Field 1: the phone number text.
    number: &'a str,
    /// Field 2: the kind of number. Named `type_` because `type` is a keyword.
    type_: &'a str,
}
/// A decoded `Person` message. String data borrows from the encoded input.
#[derive(Default, Debug)]
struct Person<'a> {
    /// Field 1: the person's name.
    name: &'a str,
    /// Field 2: numeric identifier.
    id: u64,
    /// Field 3 (repeated): every phone-number sub-message, in input order.
    phone: Vec<PhoneNumber<'a>>,
}
impl<'a> ProtoMessage<'a> for Person<'a> {
fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
match field.field_num {
1 => self.name = field.value.as_string()?,
2 => self.id = field.value.as_u64()?,
3 => self.phone.push(parse_message(field.value.as_bytes()?)?),
_ => {} // ๋๋จธ์ง๋ ๋ชจ๋ ๊ฑด๋๋๋๋ค.
}
Ok(())
}
}
impl<'a> ProtoMessage<'a> for PhoneNumber<'a> {
    /// Stores one decoded field into this `PhoneNumber`.
    ///
    /// Field 1 sets `number` and field 2 sets `type_`; both must be valid UTF-8
    /// length-delimited values. Any other field number is ignored.
    fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
        let slot = match field.field_num {
            1 => &mut self.number,
            2 => &mut self.type_,
            // Unknown field numbers are skipped.
            _ => return Ok(()),
        };
        *slot = field.value.as_string()?;
        Ok(())
    }
}
fn main() {
    // An encoded `Person`: name "maxwell", id 42, then two phone-number
    // sub-messages ("+1202-555-1212"/"home" and "+1800-867-5308"/"mobile").
    let encoded: &[u8] = &[
        0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
        0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
        0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
        0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
        0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
        0x65,
    ];
    let person = parse_message::<Person>(encoded).unwrap();
    println!("{person:#?}");
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn as_string() {
        assert!(matches!(
            FieldValue::Varint(10).as_string(),
            Err(Error::UnexpectedWireType)
        ));
        assert!(matches!(
            FieldValue::I32(10).as_string(),
            Err(Error::UnexpectedWireType)
        ));
        assert_eq!(FieldValue::Len(b"hello").as_string().unwrap(), "hello");
        // Invalid UTF-8 must be reported, not silently accepted.
        assert!(matches!(
            FieldValue::Len(&[0xffu8]).as_string(),
            Err(Error::InvalidString)
        ));
    }

    #[test]
    fn as_bytes() {
        assert!(matches!(
            FieldValue::Varint(10).as_bytes(),
            Err(Error::UnexpectedWireType)
        ));
        assert!(matches!(
            FieldValue::I32(10).as_bytes(),
            Err(Error::UnexpectedWireType)
        ));
        assert_eq!(FieldValue::Len(b"hello").as_bytes().unwrap(), b"hello");
    }

    #[test]
    fn as_u64() {
        assert_eq!(FieldValue::Varint(10).as_u64().unwrap(), 10u64);
        assert!(matches!(
            FieldValue::I32(10).as_u64(),
            Err(Error::UnexpectedWireType)
        ));
        assert!(matches!(
            FieldValue::Len(b"hello").as_u64(),
            Err(Error::UnexpectedWireType)
        ));
    }

    #[test]
    fn parse_varint_basics() {
        let (value, rest) = parse_varint(&[0x96, 0x01, 0xff]).unwrap();
        assert_eq!(value, 150);
        assert_eq!(rest, &[0xffu8][..]);
        // Truncated (continuation bit set on the last byte) and empty inputs.
        assert!(matches!(parse_varint(&[0x80]), Err(Error::InvalidVarint)));
        assert!(matches!(parse_varint(&[]), Err(Error::InvalidVarint)));
    }

    #[test]
    fn unpack_tag_splits_number_and_wire_type() {
        assert!(matches!(unpack_tag(0x1a), Ok((3, WireType::Len))));
        assert!(matches!(unpack_tag(0x08), Ok((1, WireType::Varint))));
        // Wire type 3 is not supported.
        assert!(matches!(unpack_tag(0x0b), Err(Error::InvalidWireType)));
    }

    #[test]
    fn parse_field_reads_i32_and_rejects_truncation() {
        let (field, rest) = parse_field(&[0x0d, 0x01, 0x00, 0x00, 0x00, 0x08]).unwrap();
        assert_eq!(field.field_num, 1);
        assert!(matches!(field.value, FieldValue::I32(1)));
        assert_eq!(rest, &[0x08u8][..]);

        // Declared length 5, but only one payload byte present.
        assert!(matches!(
            parse_field(&[0x0a, 0x05, b'a']),
            Err(Error::UnexpectedEOF)
        ));
        // I32 needs 4 bytes, only 2 present.
        assert!(matches!(
            parse_field(&[0x0d, 0x01, 0x02]),
            Err(Error::UnexpectedEOF)
        ));
    }

    #[test]
    fn parse_message_builds_person() {
        let person: Person = parse_message(&[
            0x0a, 0x03, b'b', b'o', b'b', // field 1 (name): "bob"
            0x10, 0x2a, // field 2 (id): 42
            0x20, 0x07, // field 4: unknown varint, skipped
            0x1a, 0x08, // field 3 (phone): 8-byte sub-message
            0x0a, 0x02, b'1', b'2', //   field 1 (number): "12"
            0x12, 0x02, b'h', b'o', //   field 2 (type): "ho"
        ])
        .unwrap();
        assert_eq!(person.name, "bob");
        assert_eq!(person.id, 42);
        assert_eq!(person.phone.len(), 1);
        assert_eq!(person.phone[0].number, "12");
        assert_eq!(person.phone[0].type_, "ho");

        let empty: Person = parse_message(&[]).unwrap();
        assert_eq!(empty.name, "");
        assert!(empty.phone.is_empty());
    }

    #[test]
    fn parse_message_reports_wrong_wire_type() {
        // Field 1 (name) encoded as a varint instead of a string.
        assert!(matches!(
            parse_message::<Person>(&[0x08, 0x01]),
            Err(Error::UnexpectedWireType)
        ));
    }
}