// 해답 (Solution)

use std::convert::TryFrom;
use thiserror::Error;

#[derive(Debug, Error)]
enum Error {
    #[error("์ž˜๋ชป๋œ varint")]
    InvalidVarint,
    #[error("์ž˜๋ชป๋œ wire-type")]
    InvalidWireType,
    #[error("์˜ˆ์ƒ์น˜ ๋ชปํ•œ EOF")]
    UnexpectedEOF,
    #[error("์ž˜๋ชป๋œ ๊ธธ์ด")]
    InvalidSize(#[from] std::num::TryFromIntError),
    #[error("์˜ˆ์ƒ์น˜ ๋ชปํ•œ wire-type")]
    UnexpectedWireType,
    #[error("์ž˜๋ชป๋œ ๋ฌธ์ž์—ด(UTF-8 ์•„๋‹˜)")]
    InvalidString,
}

/// ์™€์ด์–ด์— ํ‘œ์‹œ๋œ ์™€์ด์–ด ํƒ€์ž…์ž…๋‹ˆ๋‹ค.
enum WireType {
    /// Varint WireType์€ ๊ฐ’์ด ๋‹จ์ผ VARINT์ž„์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.
    Varint,
    //I64,  -- not needed for this exercise
    /// The Len WireType indicates that the value is a length represented as a
    /// VARINT followed by exactly that number of bytes.
    Len,
    /// The I32 WireType indicates that the value is precisely 4 bytes in
    /// little-endian order containing a 32-bit signed integer.
    I32,
}

/// A field value, typed according to the wire type it was encoded with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldValue<'a> {
    /// A decoded VARINT.
    Varint(u64),
    // I64(i64) is not needed for this exercise.
    /// The payload of a length-delimited value, borrowed from the input.
    Len(&'a [u8]),
    /// A fixed 4-byte little-endian signed integer.
    I32(i32),
}

/// One decoded field: its field number paired with its value.
#[derive(Debug)]
struct Field<'a> {
    /// The field number taken from the tag (the tag shifted right by three bits).
    field_num: u64,
    /// The field's value, tagged with the wire type it was encoded with.
    value: FieldValue<'a>,
}

/// A message type that is built up one decoded field at a time.
///
/// `parse_message` starts from `Default::default()` and passes every decoded
/// field to `add_field`.
trait ProtoMessage<'a>
where
    Self: Default + 'a,
{
    /// Incorporates one decoded field into `self`, returning an error if the
    /// field cannot be accepted.
    fn add_field(&mut self, field: Field<'a>) -> Result<(), Error>;
}

impl TryFrom<u64> for WireType {
    type Error = Error;

    fn try_from(value: u64) -> Result<WireType, Error> {
        Ok(match value {
            0 => WireType::Varint,
            //1 => WireType::I64,  -- ์ด ์—ฐ์Šต์—์„œ๋Š” ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
            2 => WireType::Len,
            5 => WireType::I32,
            _ => return Err(Error::InvalidWireType),
        })
    }
}

impl<'a> FieldValue<'a> {
    fn as_string(&self) -> Result<&'a str, Error> {
        let FieldValue::Len(data) = self else {
            return Err(Error::UnexpectedWireType);
        };
        std::str::from_utf8(data).map_err(|_| Error::InvalidString)
    }

    fn as_bytes(&self) -> Result<&'a [u8], Error> {
        let FieldValue::Len(data) = self else {
            return Err(Error::UnexpectedWireType);
        };
        Ok(data)
    }

    fn as_u64(&self) -> Result<u64, Error> {
        let FieldValue::Varint(value) = self else {
            return Err(Error::UnexpectedWireType);
        };
        Ok(*value)
    }
}

/// VARINT๋ฅผ ํŒŒ์‹ฑํ•˜์—ฌ ํŒŒ์‹ฑ๋œ ๊ฐ’๊ณผ ๋‚˜๋จธ์ง€ ๋ฐ”์ดํŠธ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
fn parse_varint(data: &[u8]) -> Result<(u64, &[u8]), Error> {
    for i in 0..7 {
        let Some(b) = data.get(i) else {
            return Err(Error::InvalidVarint);
        };
        if b & 0x80 == 0 {
            // ์ด๋Š” VARINT์˜ ๋งˆ์ง€๋ง‰ ๋ฐ”์ดํŠธ์ด๋ฏ€๋กœ
            // u64๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
            let mut value = 0u64;
            for b in data[..=i].iter().rev() {
                value = (value << 7) | (b & 0x7f) as u64;
            }
            return Ok((value, &data[i + 1..]));
        }
    }

    // 7๋ฐ”์ดํŠธ๋ฅผ ์ดˆ๊ณผํ•˜๋ฉด ์œ ํšจํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
    Err(Error::InvalidVarint)
}

/// Splits a field tag into its field number (the upper bits) and its wire type
/// (the lowest three bits).
///
/// Propagates the error from `WireType::try_from` for unsupported wire types.
fn unpack_tag(tag: u64) -> Result<(u64, WireType), Error> {
    const WIRE_TYPE_MASK: u64 = 0b111;
    let wire_type = WireType::try_from(tag & WIRE_TYPE_MASK)?;
    Ok((tag >> 3, wire_type))
}

/// ํ•„๋“œ๋ฅผ ํŒŒ์‹ฑํ•˜์—ฌ ๋‚˜๋จธ์ง€ ๋ฐ”์ดํŠธ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> {
    let (tag, remainder) = parse_varint(data)?;
    let (field_num, wire_type) = unpack_tag(tag)?;
    let (fieldvalue, remainder) = match wire_type {
        WireType::Varint => {
            let (value, remainder) = parse_varint(remainder)?;
            (FieldValue::Varint(value), remainder)
        }
        WireType::Len => {
            let (len, remainder) = parse_varint(remainder)?;
            let len: usize = len.try_into()?;
            if remainder.len() < len {
                return Err(Error::UnexpectedEOF);
            }
            let (value, remainder) = remainder.split_at(len);
            (FieldValue::Len(value), remainder)
        }
        WireType::I32 => {
            if remainder.len() < 4 {
                return Err(Error::UnexpectedEOF);
            }
            let (value, remainder) = remainder.split_at(4);
            // `value`์˜ ๊ธธ์ด๊ฐ€ 4๋ฐ”์ดํŠธ์ด๋ฏ€๋กœ ์˜ค๋ฅ˜๋ฅผ ๋ž˜ํ•‘ ํ•ด์ œํ•ฉ๋‹ˆ๋‹ค.
            let value = i32::from_le_bytes(value.try_into().unwrap());
            (FieldValue::I32(value), remainder)
        }
    };
    Ok((Field { field_num, value: fieldvalue }, remainder))
}

/// ์ฃผ์–ด์ง„ ๋ฐ์ดํ„ฐ์—์„œ ๋ฉ”์‹œ์ง€๋ฅผ ํŒŒ์‹ฑํ•˜์—ฌ ๋ฉ”์‹œ์ง€์˜ ๊ฐ ํ•„๋“œ์— ๋Œ€ํ•ด `T::add_field`๋ฅผ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค.
///
/// ์ „์ฒด ์ž…๋ ฅ์ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
    let mut result = T::default();
    while !data.is_empty() {
        let parsed = parse_field(data)?;
        result.add_field(parsed.0)?;
        data = parsed.1;
    }
    Ok(result)
}

/// A phone-number entry, decoded from field 3 of a `Person`.
#[derive(Default, Debug)]
struct PhoneNumber<'a> {
    /// The number itself (field 1).
    number: &'a str,
    /// The kind of number, e.g. "home" or "mobile" (field 2).
    type_: &'a str,
}

/// A decoded person record whose strings borrow from the input buffer.
#[derive(Default, Debug)]
struct Person<'a> {
    /// Field 1.
    name: &'a str,
    /// Field 2.
    id: u64,
    /// Every occurrence of field 3, in input order.
    phone: Vec<PhoneNumber<'a>>,
}

impl<'a> ProtoMessage<'a> for Person<'a> {
    fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
        match field.field_num {
            1 => self.name = field.value.as_string()?,
            2 => self.id = field.value.as_u64()?,
            3 => self.phone.push(parse_message(field.value.as_bytes()?)?),
            _ => {} // ๋‚˜๋จธ์ง€๋Š” ๋ชจ๋‘ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.
        }
        Ok(())
    }
}

impl<'a> ProtoMessage<'a> for PhoneNumber<'a> {
    /// Stores field 1 as `number` and field 2 as `type_`. Both must be UTF-8
    /// `Len` values. Other field numbers are ignored.
    fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
        let slot = match field.field_num {
            1 => &mut self.number,
            2 => &mut self.type_,
            // Unknown field numbers are skipped.
            _ => return Ok(()),
        };
        *slot = field.value.as_string()?;
        Ok(())
    }
}

fn main() {
    // An encoded `Person` containing two nested phone numbers.
    let encoded: &[u8] = &[
        0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
        0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
        0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
        0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
        0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
        0x65,
    ];
    let person: Person = parse_message(encoded).unwrap();
    println!("{person:#?}");
}

#[cfg(test)]
mod test {
    use super::*;

    /// The example `Person` used by `main`: name "maxwell", id 42, two phones.
    const PERSON_BYTES: &[u8] = &[
        0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
        0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
        0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
        0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
        0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
        0x65,
    ];

    #[test]
    fn as_string() {
        assert!(FieldValue::Varint(10).as_string().is_err());
        assert!(FieldValue::I32(10).as_string().is_err());
        assert_eq!(FieldValue::Len(b"hello").as_string().unwrap(), "hello");
    }

    #[test]
    fn as_string_rejects_invalid_utf8() {
        assert!(matches!(
            FieldValue::Len(&[0xff, 0xfe]).as_string(),
            Err(Error::InvalidString)
        ));
    }

    #[test]
    fn as_bytes() {
        assert!(FieldValue::Varint(10).as_bytes().is_err());
        assert!(FieldValue::I32(10).as_bytes().is_err());
        assert_eq!(FieldValue::Len(b"hello").as_bytes().unwrap(), b"hello");
    }

    #[test]
    fn as_u64() {
        assert_eq!(FieldValue::Varint(10).as_u64().unwrap(), 10u64);
        assert!(FieldValue::I32(10).as_u64().is_err());
        assert!(FieldValue::Len(b"hello").as_u64().is_err());
    }

    #[test]
    fn parse_varint_decodes_value_and_returns_remainder() {
        let (value, rest) = parse_varint(&[0x01]).unwrap();
        assert_eq!(value, 1);
        assert!(rest.is_empty());

        // 150 is encoded as 0x96 0x01; the trailing 0xff must be left over.
        let (value, rest) = parse_varint(&[0x96, 0x01, 0xff]).unwrap();
        assert_eq!(value, 150);
        assert_eq!(rest, &[0xff_u8][..]);
    }

    #[test]
    fn parse_varint_rejects_truncated_input() {
        assert!(matches!(parse_varint(&[]), Err(Error::InvalidVarint)));
        assert!(matches!(parse_varint(&[0x80]), Err(Error::InvalidVarint)));
    }

    #[test]
    fn parse_field_decodes_i32() {
        // Tag 0x0d: field 1, wire type 5 (I32).
        let (field, rest) = parse_field(&[0x0d, 0xff, 0xff, 0xff, 0xff]).unwrap();
        assert_eq!(field.field_num, 1);
        assert!(matches!(field.value, FieldValue::I32(-1)));
        assert!(rest.is_empty());
    }

    #[test]
    fn parse_field_errors() {
        // Tag 0x0b carries wire type 3, which is unsupported.
        assert!(matches!(parse_field(&[0x0b]), Err(Error::InvalidWireType)));
        // The length prefix claims 5 bytes but only 1 follows.
        assert!(matches!(
            parse_field(&[0x0a, 0x05, b'a']),
            Err(Error::UnexpectedEOF)
        ));
        // An I32 value needs 4 bytes but only 2 follow.
        assert!(matches!(
            parse_field(&[0x0d, 0x01, 0x02]),
            Err(Error::UnexpectedEOF)
        ));
    }

    #[test]
    fn parse_message_person() {
        let person: Person = parse_message(PERSON_BYTES).unwrap();
        assert_eq!(person.name, "maxwell");
        assert_eq!(person.id, 42);
        assert_eq!(person.phone.len(), 2);
        assert_eq!(person.phone[0].number, "+1202-555-1212");
        assert_eq!(person.phone[0].type_, "home");
        assert_eq!(person.phone[1].number, "+1800-867-5308");
        assert_eq!(person.phone[1].type_, "mobile");
    }

    #[test]
    fn parse_message_rejects_wrong_wire_type() {
        // Field 1 (name) encoded as a VARINT instead of a length-delimited string.
        assert!(matches!(
            parse_message::<Person>(&[0x08, 0x01]),
            Err(Error::UnexpectedWireType)
        ));
    }
}