/* Copyright (c) 2023, Google Inc.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
use crate::{
    digest::{Md, Sha256, Sha512},
    CSlice, ForeignTypeRef as _, PanicResultHandler,
};
use core::{
    ffi::{c_uint, c_void},
    marker::PhantomData,
    ptr,
};

/// One-shot HMAC-SHA-256.
///
/// Authenticates `data` under `key` in a single call and returns the resulting 32-byte tag.
///
/// # Panics
///
/// May panic if memory allocation fails in the underlying BoringSSL code.
pub fn hmac_sha_256(key: &[u8], data: &[u8]) -> [u8; 32] {
    // Tag length of HMAC-SHA-256 in bytes.
    const TAG_LEN: usize = 32;
    hmac::<TAG_LEN, Sha256>(key, data)
}

/// One-shot HMAC-SHA-512.
///
/// Authenticates `data` under `key` in a single call and returns the resulting 64-byte tag.
///
/// # Panics
///
/// May panic if memory allocation fails in the underlying BoringSSL code.
pub fn hmac_sha_512(key: &[u8], data: &[u8]) -> [u8; 64] {
    // Tag length of HMAC-SHA-512 in bytes.
    const TAG_LEN: usize = 64;
    hmac::<TAG_LEN, Sha512>(key, data)
}

/// Streaming HMAC-SHA-256 backed by BoringSSL.
///
/// Feed the message with [`HmacSha256::update`], then either extract the tag with
/// [`HmacSha256::finalize`] or check it against an expected value with one of the `verify*`
/// methods. The operations may panic if memory allocation fails in BoringSSL.
pub struct HmacSha256(Hmac<32, Sha256>);

impl HmacSha256 {
    /// Builds an instance keyed with a fixed-size, 32-byte key.
    pub fn new(key: [u8; 32]) -> Self {
        let inner = Hmac::new(key);
        Self(inner)
    }

    /// Builds an instance keyed with a key of arbitrary length.
    pub fn new_from_slice(key: &[u8]) -> Self {
        let inner = Hmac::new_from_slice(key);
        Self(inner)
    }

    /// Absorbs `data` into the running computation; may be called repeatedly.
    pub fn update(&mut self, data: &[u8]) {
        let Self(inner) = self;
        inner.update(data)
    }

    /// Consumes the instance and returns the 32-byte tag over everything absorbed so far.
    pub fn finalize(self) -> [u8; 32] {
        let Self(inner) = self;
        inner.finalize()
    }

    /// Consumes the instance and checks `tag` against the computed tag.
    ///
    /// Returns [`MacError`] if `tag` is not exactly 32 bytes long or does not match.
    pub fn verify_slice(self, tag: &[u8]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify_slice(tag)
    }

    /// Consumes the instance and checks the full-length `tag` against the computed tag.
    ///
    /// Returns [`MacError`] if the tags differ.
    pub fn verify(self, tag: [u8; 32]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify(tag)
    }

    /// Consumes the instance and checks `tag` against the leftmost `tag.len()` bytes of the
    /// computed tag.
    ///
    /// Returns [`MacError`] if `tag` is empty, longer than 32 bytes, or does not match.
    pub fn verify_truncated_left(self, tag: &[u8]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify_truncated_left(tag)
    }
}

/// Streaming HMAC-SHA-512 backed by BoringSSL.
///
/// Feed the message with [`HmacSha512::update`], then either extract the tag with
/// [`HmacSha512::finalize`] or check it against an expected value with one of the `verify*`
/// methods. The operations may panic if memory allocation fails in BoringSSL.
pub struct HmacSha512(Hmac<64, Sha512>);

impl HmacSha512 {
    /// Builds an instance keyed with a fixed-size, 64-byte key.
    pub fn new(key: [u8; 64]) -> Self {
        let inner = Hmac::new(key);
        Self(inner)
    }

    /// Builds an instance keyed with a key of arbitrary length.
    pub fn new_from_slice(key: &[u8]) -> Self {
        let inner = Hmac::new_from_slice(key);
        Self(inner)
    }

    /// Absorbs `data` into the running computation; may be called repeatedly.
    pub fn update(&mut self, data: &[u8]) {
        let Self(inner) = self;
        inner.update(data)
    }

    /// Consumes the instance and returns the 64-byte tag over everything absorbed so far.
    pub fn finalize(self) -> [u8; 64] {
        let Self(inner) = self;
        inner.finalize()
    }

    /// Consumes the instance and checks `tag` against the computed tag.
    ///
    /// Returns [`MacError`] if `tag` is not exactly 64 bytes long or does not match.
    pub fn verify_slice(self, tag: &[u8]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify_slice(tag)
    }

    /// Consumes the instance and checks the full-length `tag` against the computed tag.
    ///
    /// Returns [`MacError`] if the tags differ.
    pub fn verify(self, tag: [u8; 64]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify(tag)
    }

    /// Consumes the instance and checks `tag` against the leftmost `tag.len()` bytes of the
    /// computed tag.
    ///
    /// Returns [`MacError`] if `tag` is empty, longer than 64 bytes, or does not match.
    pub fn verify_truncated_left(self, tag: &[u8]) -> Result<(), MacError> {
        let Self(inner) = self;
        inner.verify_truncated_left(tag)
    }
}

/// Error returned when a supplied tag does not match the output of the hmac operation.
///
/// The verification methods in this module also return it when the supplied tag has an
/// unacceptable length. It deliberately carries no detail about *how* the tags differ.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MacError;

impl core::fmt::Display for MacError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("MAC tag mismatch")
    }
}

/// Private, generically implemented function for computing an hmac as a one-shot operation.
///
/// This should only be exposed publicly through functions that pair the hash function `M` with
/// the matching output size `N` (the length in bytes of the tag produced with `M`). Ideally `N`
/// would just come from `M`, but that needs const expressions over generic parameters (the
/// unstable `generic_const_exprs` feature; plain const generics are already in use here), so
/// both are passed separately for now: https://github.com/rust-lang/rust/issues/60551
///
/// Panics if BoringSSL reports an error, e.g. on allocation failure.
#[inline]
fn hmac<const N: usize, M: Md>(key: &[u8], data: &[u8]) -> [u8; N] {
    let mut out = [0_u8; N];
    let mut size: c_uint = 0;

    // Safety:
    // - `out` always contains N bytes of space; callers are responsible for choosing N to be
    //   the output size of `M`
    // - every pointer is passed together with the length of the slice it was taken from
    // - If NULL is returned on error we panic immediately
    unsafe {
        bssl_sys::HMAC(
            M::get_md().as_ptr(),
            CSlice::from(key).as_ptr(),
            key.len(),
            CSlice::from(data).as_ptr(),
            data.len(),
            out.as_mut_ptr(),
            &mut size as *mut c_uint,
        )
    }
    .panic_if_error();

    // The reported length used to be ignored. Checking it catches a wrapper that pairs `M` with
    // the wrong `N` as early as possible in debug builds. The cast only widens `c_uint`.
    debug_assert_eq!(
        size as usize, N,
        "HMAC output length does not match the declared tag size"
    );

    out
}

/// Private generically implemented hmac instance given a generic hash function and a length `N`,
/// where `N` is the output size of the hash function. This should only be exposed publicly by
/// wrapper types with the correct output size `N` which corresponds to the output size of the
/// provided generic hash function. Ideally `N` would just come from `M`, but that needs const
/// expressions over generic parameters (the unstable `generic_const_exprs` feature; plain const
/// generics are already in use here). Until then we will have to pass both separately:
/// https://github.com/rust-lang/rust/issues/60551
struct Hmac<const N: usize, M: Md> {
    // Context obtained from `HMAC_CTX_new` in `new_from_slice`; released with `HMAC_CTX_free`
    // when the value is dropped.
    ctx: *mut bssl_sys::HMAC_CTX,
    // Ties the instance to the hash function `M` without storing a value of that type.
    _marker: PhantomData<M>,
}

impl<const N: usize, M: Md> Hmac<N, M> {
    /// Infallible HMAC creation from a fixed length key.
    fn new(key: [u8; N]) -> Self {
        Self::new_from_slice(&key)
    }

    /// Create new hmac value from variable size key. Panics on allocation failure
    fn new_from_slice(key: &[u8]) -> Self {
        // Safety:
        // - HMAC_CTX_new panics if allocation fails
        let ctx = unsafe { bssl_sys::HMAC_CTX_new() };
        ctx.panic_if_error();

        // Safety:
        // - HMAC_Init_ex must be called with a context previously created with HMAC_CTX_new,
        //   which is the line above.
        // - HMAC_Init_ex may return an error if key is null but the md is different from
        //   before. This is avoided here since key is guaranteed to be non-null.
        // - HMAC_Init_ex returns 0 on allocation failure in which case we panic
        unsafe {
            bssl_sys::HMAC_Init_ex(
                ctx,
                CSlice::from(key).as_ptr() as *const c_void,
                key.len(),
                M::get_md().as_ptr(),
                ptr::null_mut(),
            )
        }
        .panic_if_error();

        Self {
            ctx,
            _marker: Default::default(),
        }
    }

    /// Update state using the provided data, can be called repeatedly.
    fn update(&mut self, data: &[u8]) {
        unsafe {
            // Safety: HMAC_Update will always return 1, in case it doesnt we panic
            bssl_sys::HMAC_Update(self.ctx, data.as_ptr(), data.len())
        }
        .panic_if_error()
    }

    /// Obtain the hmac computation consuming the hmac instance.
    fn finalize(self) -> [u8; N] {
        let mut buf = [0_u8; N];
        let mut size: c_uint = 0;
        // Safety:
        // - hmac has a fixed size output of N which will never exceed the length of an N
        // length array
        // - on allocation failure we panic
        unsafe { bssl_sys::HMAC_Final(self.ctx, buf.as_mut_ptr(), &mut size as *mut c_uint) }
            .panic_if_error();
        buf
    }

    /// Check that the tag value is correct for the processed input.
    fn verify(self, tag: [u8; N]) -> Result<(), MacError> {
        self.verify_slice(&tag)
    }

    /// Check truncated tag correctness using all bytes
    /// of calculated tag.
    ///
    /// Returns `Error` if `tag` is not valid or not equal in length
    /// to MAC's output.
    fn verify_slice(self, tag: &[u8]) -> Result<(), MacError> {
        tag.len().eq(&N).then_some(()).ok_or(MacError)?;
        self.verify_truncated_left(tag)
    }

    /// Check truncated tag correctness using left side bytes
    /// (i.e. `tag[..n]`) of calculated tag.
    ///
    /// Returns `Error` if `tag` is not valid or empty.
    fn verify_truncated_left(self, tag: &[u8]) -> Result<(), MacError> {
        let len = tag.len();
        if len == 0 || len > N {
            return Err(MacError);
        }

        let result = &self.finalize()[..len];

        // Safety:
        // - if a != b is undefined, it simply returns a non-zero result
        unsafe {
            bssl_sys::CRYPTO_memcmp(
                CSlice::from(result).as_ptr() as *const c_void,
                CSlice::from(tag).as_ptr() as *const c_void,
                result.len(),
            )
        }
        .eq(&0)
        .then_some(())
        .ok_or(MacError)
    }
}

impl<const N: usize, M: Md> Drop for Hmac<N, M> {
    /// Releases the BoringSSL context owned by this instance.
    fn drop(&mut self) {
        let ctx = self.ctx;
        // Safety:
        // - `ctx` was obtained from HMAC_CTX_new when the instance was created
        // - the field is private and is only released here, when the owning value is dropped
        unsafe {
            bssl_sys::HMAC_CTX_free(ctx);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by most tests: twenty `0x0b` bytes (RFC 4231, test case 1).
    const KEY: [u8; 20] = [0x0b; 20];

    /// Message shared by all tests (RFC 4231, test case 1).
    const DATA: &[u8] = b"Hi There";

    /// Expected HMAC-SHA-256 of `DATA` under `KEY` (RFC 4231, test case 1).
    const EXPECTED_SHA256: [u8; 32] = [
        0xb0, 0x34, 0x4c, 0x61, 0xd8, 0xdb, 0x38, 0x53, 0x5c, 0xa8, 0xaf, 0xce, 0xaf, 0xb,
        0xf1, 0x2b, 0x88, 0x1d, 0xc2, 0x0, 0xc9, 0x83, 0x3d, 0xa7, 0x26, 0xe9, 0x37, 0x6c,
        0x2e, 0x32, 0xcf, 0xf7,
    ];

    /// Returns an HMAC-SHA-256 instance keyed with `KEY` that has already absorbed `DATA`.
    ///
    /// The `verify*` methods consume the instance, so each check needs a fresh one.
    fn sha256_over_data() -> HmacSha256 {
        let mut hmac = HmacSha256::new_from_slice(&KEY);
        hmac.update(DATA);
        hmac
    }

    #[test]
    fn hmac_sha256_test() {
        let mut hmac = HmacSha256::new_from_slice(&KEY);
        hmac.update(DATA);
        let hmac_result: [u8; 32] = hmac.finalize();

        assert_eq!(&hmac_result, &EXPECTED_SHA256);
    }

    #[test]
    fn hmac_sha256_fixed_size_key_test() {
        let expected_hmac = [
            0x19, 0x8a, 0x60, 0x7e, 0xb4, 0x4b, 0xfb, 0xc6, 0x99, 0x3, 0xa0, 0xf1, 0xcf, 0x2b,
            0xbd, 0xc5, 0xba, 0xa, 0xa3, 0xf3, 0xd9, 0xae, 0x3c, 0x1c, 0x7a, 0x3b, 0x16, 0x96,
            0xa0, 0xb6, 0x8c, 0xf7,
        ];

        let key: [u8; 32] = [0x0b; 32];

        let mut hmac = HmacSha256::new(key);
        hmac.update(DATA);
        let hmac_result: [u8; 32] = hmac.finalize();
        assert_eq!(&hmac_result, &expected_hmac);
    }

    #[test]
    fn hmac_sha256_update_test() {
        let mut hmac: HmacSha256 = HmacSha256::new_from_slice(&KEY);
        hmac.update(DATA);
        let result = hmac.finalize();
        assert_eq!(&result, &EXPECTED_SHA256);
        assert_eq!(result.len(), 32);
    }

    /// Exercises the one-shot `hmac_sha_256` function rather than the streaming type.
    #[test]
    fn hmac_sha256_test_big_buffer() {
        let hmac_result = hmac_sha_256(&KEY, DATA);
        assert_eq!(&hmac_result, &EXPECTED_SHA256);
    }

    #[test]
    fn hmac_sha256_update_chunks_test() {
        let mut hmac = HmacSha256::new_from_slice(&KEY);
        hmac.update(b"Hi");
        hmac.update(b" There");
        let result = hmac.finalize();
        assert_eq!(&result, &EXPECTED_SHA256);
    }

    #[test]
    fn hmac_sha256_verify_test() {
        assert!(sha256_over_data().verify(EXPECTED_SHA256).is_ok())
    }

    #[test]
    fn hmac_sha256_verify_rejects_wrong_tag_test() {
        let mut wrong_tag = EXPECTED_SHA256;
        wrong_tag[31] ^= 0x01;

        assert!(sha256_over_data().verify(wrong_tag).is_err());
        assert!(sha256_over_data().verify_slice(&wrong_tag).is_err());
    }

    #[test]
    fn hmac_sha256_verify_slice_test() {
        assert!(sha256_over_data().verify_slice(&EXPECTED_SHA256).is_ok());

        // A correct prefix is still rejected: `verify_slice` requires the full length.
        assert!(sha256_over_data()
            .verify_slice(&EXPECTED_SHA256[..31])
            .is_err());
        assert!(sha256_over_data().verify_slice(&[]).is_err());
    }

    #[test]
    fn hmac_sha256_verify_truncated_left_test() {
        assert!(sha256_over_data()
            .verify_truncated_left(&EXPECTED_SHA256[..16])
            .is_ok());
        assert!(sha256_over_data()
            .verify_truncated_left(&EXPECTED_SHA256)
            .is_ok());

        // Empty and over-long tags are rejected outright.
        assert!(sha256_over_data().verify_truncated_left(&[]).is_err());
        let mut too_long = [0_u8; 33];
        too_long[..32].copy_from_slice(&EXPECTED_SHA256);
        assert!(sha256_over_data().verify_truncated_left(&too_long).is_err());

        // A prefix with a flipped bit does not match.
        let mut wrong_prefix = [0_u8; 16];
        wrong_prefix.copy_from_slice(&EXPECTED_SHA256[..16]);
        wrong_prefix[0] ^= 0x80;
        assert!(sha256_over_data()
            .verify_truncated_left(&wrong_prefix)
            .is_err());
    }

    #[test]
    fn hmac_sha512_one_shot_matches_streaming_test() {
        let one_shot: [u8; 64] = hmac_sha_512(&KEY, DATA);

        let mut hmac = HmacSha512::new_from_slice(&KEY);
        hmac.update(b"Hi");
        hmac.update(b" There");
        let streamed: [u8; 64] = hmac.finalize();

        assert_eq!(&one_shot, &streamed);

        let mut hmac = HmacSha512::new_from_slice(&KEY);
        hmac.update(DATA);
        assert!(hmac.verify(one_shot).is_ok());
    }
}
