/*
 * Copyright (C) 2009 - 2010 Funambol, Inc.
 *
 * This program is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License version 3 as published by
 * the Free Software Foundation with the addition of the following permission
 * added to Section 15 as permitted in Section 7(a): FOR ANY PART OF THE COVERED
 * WORK IN WHICH THE COPYRIGHT IS OWNED BY FUNAMBOL, FUNAMBOL DISCLAIMS THE
 * WARRANTY OF NON INFRINGEMENT OF THIRD PARTY RIGHTS.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program; if not, see http://www.gnu.org/licenses or write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA 02110-1301 USA.
 *
 * You can contact Funambol, Inc. headquarters at 643 Bair Island Road, Suite
 * 305, Redwood City, CA 94063, USA, or at email address info@funambol.com.
 *
 * The interactive user interfaces in modified source and object code versions
 * of this program must display Appropriate Legal Notices, as required under
 * Section 5 of the GNU Affero General Public License version 3.
 *
 * In accordance with Section 7(b) of the GNU Affero General Public License
 * version 3, these Appropriate Legal Notices must retain the display of the
 * "Powered by Funambol" logo. If the display of the logo is not reasonably
 * feasible for technical reasons, the Appropriate Legal Notices must display
 * the words "Powered by Funambol".
 */

/* $Id$ */

#include "Crypto.h"

#include <stdio.h>
#include <string.h>

extern "C"
{
#include "openssl/aes.h"
  //#include "openssl/aes_ccm.h"
#include "openssl/evp.h"
#include "openssl/hmac.h"
}

namespace NS_DM_Client
{

// Helpers of the in-file AES-CCM implementation; their definitions appear
// further down in this file.

// dest[i] = s1[i] ^ s2[i] for the first blen bytes (dest may be s1 or s2).
void do_xor(unsigned char *dest, unsigned char *s1, unsigned char *s2, unsigned blen);
// One CBC-MAC step on 16-byte blocks: A ^= B, then B = AES_encrypt(A).
void ccm_hash(unsigned char *B, unsigned char *A, AES_KEY *akey);
// Builds the initial counter block: flags byte (q - 1), the nonce, zero counter.
void ccm_format_ctr(
    unsigned char *CTR, unsigned char *iv, unsigned long ivlen, unsigned int q);
// Advances the counter kept in the last q bytes of a 16-byte counter block.
void aes_ccm_inc(unsigned char *counter, unsigned q);
// Shared AES-CCM routine (enc != 0 encrypts, 0 decrypts and verifies the tag);
// returns 1 on success, 0 on failure.
int Internal_AES_CCM_Common(
    unsigned char *iv,unsigned int ivlen,
    unsigned char *key,unsigned int keylen,
    unsigned char *data,unsigned long datalen,
    unsigned char *out, unsigned long *outlen,
    unsigned int taglen,int enc
);
// Encrypt entry point: Internal_AES_CCM_Common in encrypt mode.
int Internal_AES_CCM_Encrypt(
    unsigned char *iv,unsigned int ivlen,
    unsigned char *key,unsigned int keylen,
    unsigned char *data,unsigned long datalen,
    unsigned char *out, unsigned long *outlen,
    unsigned int taglen
);
// Decrypt entry point: Internal_AES_CCM_Common in decrypt mode; wipes the
// output buffer when the common routine fails.
int Internal_AES_CCM_Decrypt(
    unsigned char *iv,unsigned int ivlen,
    unsigned char *key,unsigned int keylen,
    unsigned char *data,unsigned long datalen,
    unsigned char *out, unsigned long *outlen,
    unsigned int taglen
);

// Select which CCM implementation the OpenSSLCrypto methods call.
// With OPENSSL_INTERNAL the routines defined in this file are used. Otherwise
// global ::AES_CCM_Encrypt / ::AES_CCM_Decrypt must be supplied by another
// library; the commented-out "openssl/aes_ccm.h" include suggests a patched
// OpenSSL build - TODO confirm which library provides them.
// AAA_DATA expands to the extra arguments placed between the key and the data
// arguments at the call sites: nothing for the internal routines, "NULL, 0,"
// for the external ones (presumably an empty associated-data buffer - verify).
#if defined(OPENSSL_INTERNAL)
    #define AES_CCM_ENCRYPT Internal_AES_CCM_Encrypt
    #define AES_CCM_DECRYPT Internal_AES_CCM_Decrypt
    #define AAA_DATA
#else
    #define AES_CCM_ENCRYPT ::AES_CCM_Encrypt
    #define AES_CCM_DECRYPT ::AES_CCM_Decrypt
    #define AAA_DATA NULL, 0,

#endif

//------------------------------------------------------------------------------------------------------

/*
 * HMAC_SHA256
 *
 * Computes HMAC-SHA256 of inputData under key using OpenSSL's one-shot HMAC().
 *
 * outputData        receives the digest (32 bytes for SHA-256)
 * outputDataLength  set to the number of digest bytes written, or 0 on failure
 *
 * Returns true when HMAC() succeeds, false otherwise.
 */
bool OpenSSLCrypto::HMAC_SHA256
(
    const unsigned char* inputData, size_t inputDataLength,
    const unsigned char* key, size_t keyLength,
    unsigned char* outputData, size_t& outputDataLength
)
{
    // HMAC() reports the digest size through an unsigned int*. Passing a cast
    // of &outputDataLength (a size_t) would write only part of the size_t, and
    // the wrong part on big-endian targets. Receive the value in a correctly
    // typed local instead and widen it afterwards.
    unsigned int digestLength = 0;

    if (HMAC(EVP_sha256(), key, keyLength, inputData, inputDataLength, outputData, &digestLength) != NULL)
    {
        outputDataLength = digestLength;
        return true;
    }

    outputDataLength = 0;
    return false;
}

//------------------------------------------------------------------------------------------------------

/*
 * AES_CCM_Encrypt
 *
 * Encrypts inputData with AES-CCM (no associated data) using an 8-byte tag,
 * through the implementation selected by AES_CCM_ENCRYPT.
 *
 * outputData        with the in-file implementation (OPENSSL_INTERNAL) it
 *                   receives the ciphertext followed by the 8-byte tag
 * outputDataLength  set to the number of bytes produced, or 0 on failure
 *
 * Returns true only when the CCM routine reports success (returns 1).
 */
bool OpenSSLCrypto::AES_CCM_Encrypt
(
    const unsigned char* iv, size_t ivLength,
    const unsigned char* key, size_t keyLength,
    const unsigned char* inputData, size_t inputDataLength,
    unsigned char *outputData, size_t& outputDataLength
)
{
    // authentication tag length in bytes handed to the CCM routine
    const unsigned int tagLength = 8;
    unsigned long resLen = 0;

    // The in-file CCM routine takes IV/key lengths as unsigned int and the data
    // length as unsigned long. Refuse values that would be silently truncated
    // by that narrowing instead of encrypting only part of the input.
    if (static_cast<unsigned int>(ivLength) != ivLength ||
        static_cast<unsigned int>(keyLength) != keyLength ||
        static_cast<unsigned long>(inputDataLength) != inputDataLength)
    {
        outputDataLength = 0;
        return false;
    }

    if (AES_CCM_ENCRYPT(
        const_cast<unsigned char*>(iv), ivLength,
        const_cast<unsigned char*>(key), keyLength,
        AAA_DATA
        const_cast<unsigned char*>(inputData), inputDataLength,
        outputData, &resLen,
        tagLength) == 1)
    {
        outputDataLength = size_t(resLen);
        return true;
    }

    outputDataLength = 0;
    return false;
}

//------------------------------------------------------------------------------------------------------

/*
 * AES_CCM_Decrypt
 *
 * Decrypts and authenticates inputData with AES-CCM (no associated data,
 * 8-byte tag) through the implementation selected by AES_CCM_DECRYPT.
 *
 * inputData         with the in-file implementation (OPENSSL_INTERNAL): the
 *                   ciphertext followed by the 8-byte tag
 * outputData        receives the recovered plaintext
 * outputDataLength  set to the number of plaintext bytes, or 0 on failure
 *
 * Returns true only when the CCM routine reports success (returns 1), i.e.
 * the tag verified.
 */
bool OpenSSLCrypto::AES_CCM_Decrypt
(
    const unsigned char* iv, size_t ivLength,
    const unsigned char* key, size_t keyLength,
    const unsigned char* inputData, size_t inputDataLength,
    unsigned char *outputData, size_t& outputDataLength
)
{
    // authentication tag length in bytes handed to the CCM routine
    const unsigned int tagLength = 8;
    unsigned long resLen = 0;

    // The in-file CCM routine takes IV/key lengths as unsigned int and the data
    // length as unsigned long. Refuse values that would be silently truncated
    // by that narrowing instead of processing a different length.
    if (static_cast<unsigned int>(ivLength) != ivLength ||
        static_cast<unsigned int>(keyLength) != keyLength ||
        static_cast<unsigned long>(inputDataLength) != inputDataLength)
    {
        outputDataLength = 0;
        return false;
    }

    if (AES_CCM_DECRYPT(
        const_cast<unsigned char*>(iv), ivLength,
        const_cast<unsigned char*>(key), keyLength,
        AAA_DATA
        const_cast<unsigned char*>(inputData), inputDataLength,
        outputData, &resLen,
        tagLength) == 1)
    {
        outputDataLength = size_t(resLen);
        return true;
    }

    outputDataLength = 0;
    return false;
}

//------------------------------------------------------------------------------------------------------

/*
 * Internal_AES_CCM_Common
 *
 * Shared AES-CCM routine (counter-mode encryption + CBC-MAC, laid out as in
 * NIST SP 800-38C / RFC 3610) without associated data. It is used by
 * Internal_AES_CCM_Encrypt and Internal_AES_CCM_Decrypt.
 *
 * iv, ivlen     nonce of 7..13 bytes; the remaining q = 15 - ivlen bytes of a
 *               block hold the payload length (B0) and the block counter
 * key, keylen   raw AES key, keylen in bytes (AES_set_encrypt_key is called
 *               with keylen * 8 bits)
 * data, datalen encrypt: plaintext of datalen bytes
 *               decrypt: ciphertext followed by the taglen-byte tag
 * out           encrypt: receives datalen + taglen bytes (ciphertext || tag)
 *               decrypt: receives datalen - taglen bytes of plaintext
 * outlen        receives the number of bytes written to out; it is set to 0
 *               up front and is 0 after every failure
 * taglen        tag length in bytes: 4, 6, 8, 10, 12, 14 or 16
 * enc           non-zero to encrypt, zero to decrypt and verify
 *
 * Returns 1 on success. Returns 0 on invalid parameters, key-setup failure or,
 * in decrypt mode, a tag mismatch; on a tag mismatch the plaintext already
 * written to out is wiped.
 */
int Internal_AES_CCM_Common(unsigned char *iv,unsigned int ivlen,
                            unsigned char *key,unsigned int keylen,
                            unsigned char *data,unsigned long datalen,
                            unsigned char *out, unsigned long *outlen,
                            unsigned int taglen,int enc
                            )
{
    int rv = 1;
    unsigned char A0[AES_BLOCK_SIZE];
    unsigned char B0[AES_BLOCK_SIZE];
    unsigned char CTR[AES_BLOCK_SIZE];
    unsigned char S0[AES_BLOCK_SIZE];
    unsigned q = 0;
    unsigned long p = 0;
    unsigned char *oldout = out;
    // The key schedule lives on the stack: there is no allocation that can fail
    // or leak, and it is wiped before returning.
    AES_KEY akey;
    int i,j;

    // Report "nothing written" up front so no failure path leaves a stale length.
    if(NULL == outlen)
    {
        rv = 0;
    }
    else
    {
        *outlen = 0;
    }

    // check for a valid nonce length 7-13
    if(ivlen < 7 || ivlen > 13)
    {
        rv = 0;
    }
    else
    {
        /* q is the number of bytes left in a block for the payload length
           (in B0) and for the block counter (in the counter blocks) */
        q = 15 - ivlen;
    }

    // check for a valid tag length 4,6,8,10,12,14,16
    if(0 != (taglen & 1) || (taglen < 4) || (taglen > 16))
    {
        rv = 0;
    }

    // In decrypt mode the input must at least contain the tag; otherwise
    // datalen - taglen below would wrap around to a huge length.
    if(!enc && datalen < taglen)
    {
        rv = 0;
    }

    if(1 == rv)
    {
        // The payload length must fit in the q-byte length field of B0.
        // The shift is guarded: when q covers every byte of unsigned long,
        // any value fits and shifting by the full width would be undefined.
        const unsigned long payloadlen = enc ? datalen : datalen - taglen;
        if(q < sizeof(unsigned long) && 0 != (payloadlen >> (8 * q)))
        {
            rv = 0;
        }
    }

    if(1 == rv)
    {
        // To prevent the user from knowing why we failed, always flush the
        // output buffer, but only the bytes this call can actually produce.
        if(enc)
        {
            memset(out, 0, datalen + taglen);
        }
        else
        {
            memset(out, 0, datalen - taglen);
        }

        if(::AES_set_encrypt_key(key, keylen * 8, &akey) != 0)
        {
            rv = 0;
        }
    }

    if(1 == rv)
    {
        // if we are decrypting, we don't decrypt the tag ...
        if(!enc)
        {
            datalen -= taglen;
        }

        // format the first block: flags (tag size, q), nonce, payload length
        memset(B0, 0, AES_BLOCK_SIZE);
        B0[0] |= (((taglen - 2) / 2) & 7) << 3;
        B0[0] |= (q - 1) & 7;

        // copy in the IV fields
        for(i = 1; i <= static_cast<int>(ivlen); i++)
        {
            B0[i] = iv[i-1];
        }

        // encode the number of bytes of plaintext (most significant byte first)
        // in the remaining q bytes
        p = datalen;
        for(j = 15; j >= i; j--)
        {
            B0[j] = static_cast<unsigned char>(p & 0xff);
            p >>= 8;
        }

        memset(A0, 0, AES_BLOCK_SIZE); // start with original data == 0
        ccm_hash(B0, A0, &akey);
        ccm_format_ctr(CTR, iv, ivlen, q);
        memset(A0, 0, AES_BLOCK_SIZE);

        // process the first counter block; its keystream (S0) masks the tag
        ::AES_encrypt(CTR, A0, &akey);
        aes_ccm_inc(CTR, q);
        memcpy(S0, A0, AES_BLOCK_SIZE);

        /* now do the encrypt/decrypt phase, note that we are still adding data to the
           hash function at this phase as well */
        for( ; datalen > 0; datalen -= AES_BLOCK_SIZE)
        {
            int l = AES_BLOCK_SIZE;
            if(datalen < AES_BLOCK_SIZE)
            {
                l = static_cast<int>(datalen);
                // in this case we need to zero A0
                memset(A0, 0, AES_BLOCK_SIZE);
            }

            if(enc)
            {
                memcpy(A0, data, l);
                ccm_hash(B0, A0, &akey);
            }

            // prep the counter - encrypt the current value
            ::AES_encrypt(CTR, A0, &akey);
            // increment the counter
            aes_ccm_inc(CTR, q);

            // XOR the encrypted counter with the incoming data
            do_xor(A0, data, A0, l);
            // copy this to the output
            memcpy(out, A0, l);

            // in decrypt mode, that gave us plaintext which we hash
            if(!enc)
            {
                if(l != AES_BLOCK_SIZE)
                {
                    // zero-pad the final partial block before hashing
                    memset(A0 + l, 0, AES_BLOCK_SIZE - l);
                }
                ccm_hash(B0, A0, &akey);
            }

            data += l;
            out += l;
            *outlen += l;

            if(datalen < AES_BLOCK_SIZE)
            {
                break;
            }
        }
    }

    if(1 == rv)
    {
        // last encrypted hash block is our tag
        // XOR this with the saved first counter block
        do_xor(B0, S0, B0, taglen);
        if(enc)
        {
            // append this to the encrypted data in encrypt mode
            memcpy(out, B0, taglen);
            *outlen += taglen;
        }
        else
        {
            // In decrypt mode compare it with the tag at the end of the data.
            // Differences are accumulated instead of using memcmp, so the time
            // taken does not reveal how many leading tag bytes matched.
            unsigned char diff = 0;
            for(i = 0; i < static_cast<int>(taglen); i++)
            {
                diff |= static_cast<unsigned char>(B0[i] ^ data[i]);
            }
            if(0 != diff)
            {
                rv = 0;
                memset(oldout, 0, *outlen);
                *outlen = 0;
            }
        }
    }

    // Wipe the expanded key. Writing through a volatile pointer discourages the
    // compiler from discarding these stores as dead.
    {
        volatile unsigned char *wipe = reinterpret_cast<volatile unsigned char *>(&akey);
        for(unsigned k = 0; k < sizeof(akey); k++)
        {
            wipe[k] = 0;
        }
    }

    return rv;
}

//------------------------------------------------------------------------------------------------------

/*
 * Internal_AES_CCM_Encrypt
 *
 * Thin wrapper that runs Internal_AES_CCM_Common in encrypt mode: out receives
 * the ciphertext followed by a taglen-byte tag. The common routine's result
 * (1 on success, 0 on failure) is passed through unchanged.
 */
int Internal_AES_CCM_Encrypt(
    unsigned char *iv,unsigned int ivlen,
    unsigned char *key,unsigned int keylen,
    unsigned char *data,unsigned long datalen,
    unsigned char *out, unsigned long *outlen,
    unsigned int taglen
)
{
    const int encryptMode = 1;
    return Internal_AES_CCM_Common(iv, ivlen, key, keylen, data, datalen,
                                   out, outlen, taglen, encryptMode);
}

//------------------------------------------------------------------------------------------------------

/*
 * Internal_AES_CCM_Decrypt
 *
 * Runs Internal_AES_CCM_Common in decrypt mode: data holds the ciphertext
 * followed by a taglen-byte tag, and out receives the plaintext.
 *
 * Returns 1 on success. On failure it returns 0, wipes whatever plaintext had
 * been written to out, and sets *outlen to 0 (when outlen is non-NULL).
 */
int Internal_AES_CCM_Decrypt(
    unsigned char *iv,unsigned int ivlen,
    unsigned char *key,unsigned int keylen,
    unsigned char *data,unsigned long datalen,
    unsigned char *out, unsigned long *outlen,
    unsigned int taglen
)
{
    // Start from "nothing written". Otherwise an early parameter-validation
    // failure in the common routine could leave a caller-supplied (possibly
    // garbage) length behind, which the wipe below would then use.
    if(NULL != outlen)
    {
        *outlen = 0;
    }

    const int rv = Internal_AES_CCM_Common(iv, ivlen, key, keylen, data, datalen, out, outlen, taglen, 0);

    // The common routine also fails when outlen is NULL, so guard the dereference.
    if(0 == rv && NULL != outlen)
    {
        // if anything failed, kill the output buffer
        if(0 != *outlen)
        {
            memset(out, 0, *outlen);
        }
        *outlen = 0;
    }
    return rv;
}

//------------------------------------------------------------------------------------------------------

/*
 * do_xor
 *
 * Stores the bytewise XOR of s1 and s2 into dest for the first blen bytes.
 * Each output byte depends only on the input bytes at the same offset, so
 * dest may be the same buffer as s1 or s2 (the CCM code relies on this for
 * in-place XOR).
 */
void do_xor(unsigned char *dest, unsigned char *s1, unsigned char *s2, unsigned blen)
{
    unsigned char *d = dest;
    const unsigned char *const stop = s1 + blen;
    while (s1 != stop)
    {
        *d++ = static_cast<unsigned char>(*s1++ ^ *s2++);
    }
}

//------------------------------------------------------------------------------------------------------

/*
 * ccm_hash
 *
 * One CBC-MAC chaining step over 16-byte blocks:
 *   A <- A xor B   (A holds the next input block on entry and is overwritten)
 *   B <- AES(A)    (B carries the running MAC between calls)
 */
void ccm_hash(unsigned char *B, unsigned char *A, AES_KEY *akey)
{
    for (unsigned n = 0; n < 16; ++n)
    {
        A[n] ^= B[n];
    }
    ::AES_encrypt(A, B, akey);
}

//------------------------------------------------------------------------------------------------------

/*
 * ccm_format_ctr
 *
 * Builds the initial CCM counter block in CTR (AES_BLOCK_SIZE bytes):
 *   byte 0           flags, holding only (q - 1) in its low three bits
 *   bytes 1..ivlen   the nonce
 *   remaining bytes  the counter, starting at zero
 */
void ccm_format_ctr(
    unsigned char *CTR,
    unsigned char *iv,
    unsigned long ivlen,
    unsigned int q
)
{
    memset(CTR, 0, AES_BLOCK_SIZE);
    memcpy(&CTR[1], iv, ivlen);
    CTR[0] = static_cast<unsigned char>((q - 1) & 7);
}

//------------------------------------------------------------------------------------------------------

/*
 * aes_ccm_inc
 *
 * Adds one to the counter stored most-significant-byte first in the last q
 * bytes (counter[16 - q] .. counter[15]) of a 16-byte counter block, wrapping
 * modulo 2^(8q). Bytes outside the counter field are never touched.
 *
 * The previous version set a carry flag that was never cleared. After one byte
 * wrapped, it incremented every remaining counter byte (00 00 ff -> 01 01 00
 * instead of 00 01 00).
 */
void aes_ccm_inc(unsigned char *counter, unsigned q)
{
    // Walk up from the least significant byte. A byte that does not wrap to 0
    // absorbs the carry, and the increment is complete.
    for(int i = 15; q > 0; i--, q--)
    {
        if(0 != ++counter[i])
        {
            break;
        }
    }
}

//------------------------------------------------------------------------------------------------------

/*
 * Release
 *
 * Destroys this object via `delete this`. This is valid only for heap objects;
 * CreateCryptoImpl() below allocates them with new. The object (and any
 * pointer to it) must not be used after this call.
 */
void OpenSSLCrypto::Release()
{
    delete this;
}

//------------------------------------------------------------------------------------------------------

// Destructor: intentionally empty. The methods in this file keep all their
// working state in locals; any data members are declared in Crypto.h, which is
// not visible here - NOTE(review): confirm none need explicit cleanup.
OpenSSLCrypto::~OpenSSLCrypto()
{

}

//------------------------------------------------------------------------------------------------------

/*
 * CreateCryptoImpl
 *
 * Factory for the OpenSSL-backed ICrypto implementation. The object is
 * allocated with new; dispose of it through Release(), which deletes it.
 */
ICrypto* CreateCryptoImpl()
{
    OpenSSLCrypto *impl = new OpenSSLCrypto;
    return impl;
}

//------------------------------------------------------------------------------------------------------

}
