/*
 *
 *    Copyright (c) 2019 Google LLC
 *    All rights reserved.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

/**
 *    @file
 *      This file defines common preprocessor definitions, constants,
 *      functions, and globals for unit and functional tests for the
 *      Inet layer.
 *
 */

#include "TestInetLayerCommon.hpp"

#include <stdlib.h>
#include <string.h>

#include <nlbyteorder.hpp>

#include <InetLayer/InetLayer.h>

#include <Weave/Support/CodeUtils.h>

#include "ToolCommon.h"

using namespace ::nl::Inet;
using namespace ::nl::Weave::System;


// Type Definitions

// Header of an ICMP echo request/reply message: type, code, checksum,
// identifier and sequence number (8 bytes, packed so no padding is
// inserted). ICMPv4 and ICMPv6 echo messages share this layout, so one
// struct is aliased for both versions below.
//
// Instances are overlaid directly on the start of a packet buffer (see
// MakeICMPDataBuffer), so multi-byte fields hold whatever byte order the
// writer stored; MakeICMPDataBuffer stores the sequence number in
// big-endian (network) order.

struct ICMPEchoHeader {
    uint8_t  mType;
    uint8_t  mCode;
    uint16_t mChecksum;
    uint16_t mID;
    uint16_t mSequenceNumber;
} __attribute__((packed));

// Per-IP-version names for the shared echo header. Note that these are
// the same type, so templates instantiated on either are one instantiation.
typedef struct ICMPEchoHeader ICMPv4EchoHeader;
typedef struct ICMPEchoHeader ICMPv6EchoHeader;


// Global Variables

// ICMP echo message type codes. IPv4 (RFC 792) and IPv6 (RFC 4443) use
// different numeric values for the same echo request/reply messages.

static const uint8_t     kICMPv4_EchoRequest   = 8;
static const uint8_t     kICMPv4_EchoReply     = 0;

static const uint8_t     kICMPv6_EchoRequest   = 128;
static const uint8_t     kICMPv6_EchoReply     = 129;

// ICMP types of interest for each IP version. MakeICMPv4DataBuffer and
// MakeICMPv6DataBuffer index these with kICMP_EchoRequestIndex, so the
// element order must agree with the index constants declared alongside
// kICMPv4_FilterTypes / kICMPv6_FilterTypes (presumably also used to
// build a receive filter, per the "FilterTypes" name -- TODO confirm).

const uint8_t            gICMPv4Types[kICMPv4_FilterTypes] =
{
    kICMPv4_EchoRequest,
    kICMPv4_EchoReply
};

const uint8_t            gICMPv6Types[kICMPv6_FilterTypes] =
{
    kICMPv6_EchoRequest,
    kICMPv6_EchoReply
};

// Set to true by HandleSendTimerComplete() when the send timer fires.
// Starts out true, presumably so the first send is not delayed.
bool                     gSendIntervalExpired  = true;

// Interval between sends; milliseconds, per the name.
uint32_t                 gSendIntervalMs       = 1000;

// Optional interface selection; NULL / INET_NULL_INTERFACEID mean "none
// specified" (presumably populated from tool options -- verify there).
const char *             gInterfaceName        = NULL;

InterfaceId              gInterfaceId          = INET_NULL_INTERFACEID;

// Default size of each send; presumably bytes and overridable by options.
uint16_t                 gSendSize             = 59;

// Option flag bit set; IsReceiver() tests kOptFlagListen in it.
uint32_t                 gOptFlags             = 0;


namespace Common
{

/**
 *  Report whether this instance was configured to receive, i.e. whether
 *  every bit of kOptFlagListen is present in gOptFlags.
 */
bool IsReceiver(void)
{
    if ((gOptFlags & kOptFlagListen) != kOptFlagListen)
        return false;

    return true;
}

bool IsSender(void)
{
    return (!IsReceiver());
}

/**
 *  Report whether a test is still in progress, meaning it has neither
 *  failed nor succeeded yet.
 */
bool IsTesting(const TestStatus &aTestStatus)
{
    return !(aTestStatus.mFailed || aTestStatus.mSucceeded);
}

/**
 *  Report whether a test succeeded. Failure takes precedence: a status
 *  marked as failed is never successful, even if also marked succeeded.
 */
bool WasSuccessful(const TestStatus &aTestStatus)
{
    return (!aTestStatus.mFailed && aTestStatus.mSucceeded);
}

/**
 *  Write an incrementing byte pattern into aBuffer[aPatternStartOffset,
 *  aLength). The first written byte is aFirstValue, and each following
 *  byte is one greater, wrapping modulo 256. Bytes before the start
 *  offset are left untouched.
 */
static void FillDataBufferPattern(uint8_t *aBuffer, uint16_t aLength, uint16_t aPatternStartOffset, uint8_t aFirstValue)
{
    uint8_t lNextValue = aFirstValue;

    for (uint16_t lOffset = aPatternStartOffset; lOffset < aLength; ++lOffset)
    {
        aBuffer[lOffset] = lNextValue++;
    }
}

/**
 *  Verify that aBuffer[aPatternStartOffset, aLength) holds the
 *  incrementing byte pattern starting at aFirstValue, as written by
 *  FillDataBufferPattern.
 *
 *  On the first mismatch, print the offending offset with the expected
 *  and actual values, hex-dump the patterned region, and return false.
 *  Return true if every byte matches.
 */
static bool CheckDataBufferPattern(const uint8_t *aBuffer, uint16_t aLength, uint16_t aPatternStartOffset, uint8_t aFirstValue)
{
    uint8_t lExpected = aFirstValue;

    for (uint16_t lOffset = aPatternStartOffset; lOffset < aLength; ++lOffset, ++lExpected)
    {
        const uint8_t lActual = aBuffer[lOffset];

        if (lActual != lExpected)
        {
            printf("Bad data value at offset %u (0x%04x): "
                   "expected 0x%02x, found 0x%02x\n",
                   lOffset, lOffset, lExpected, lActual);

            DumpMemory(aBuffer + aPatternStartOffset,
                       aLength - aPatternStartOffset,
                       "0x",
                       16);

            return false;
        }
    }

    return true;
}

/**
 *  Allocate a packet buffer and fill it with an incrementing byte pattern.
 *
 *  The requested length is clamped to the buffer's capacity. The pattern
 *  starts at aPatternStartOffset with value aFirstValue.
 *
 *  Returns NULL if the start offset lies beyond the requested length or
 *  if allocation fails.
 */
static PacketBuffer *MakeDataBuffer(uint16_t aDesiredLength, uint16_t aPatternStartOffset, uint8_t aFirstValue)
{
    if (aPatternStartOffset > aDesiredLength)
        return NULL;

    PacketBuffer *lResult = PacketBuffer::New();

    if (lResult != NULL)
    {
        const uint16_t lLength = min(lResult->MaxDataLength(), aDesiredLength);

        FillDataBufferPattern(lResult->Start(), lLength, aPatternStartOffset, aFirstValue);
        lResult->SetDataLength(lLength);
    }

    return lResult;
}

/**
 *  Allocate a patterned packet buffer whose pattern begins with the value
 *  0 at aPatternStartOffset. Returns NULL on failure.
 */
static PacketBuffer *MakeDataBuffer(uint16_t aDesiredLength, uint16_t aPatternStartOffset)
{
    return MakeDataBuffer(aDesiredLength, aPatternStartOffset, static_cast<uint8_t>(0));
}

/**
 *  Allocate a buffer holding an ICMP echo header followed by patterned
 *  user data.
 *
 *  @tparam tType               The echo header type overlaid on the
 *                              start of the buffer.
 *  @param aDesiredUserLength   Requested number of user-data bytes.
 *  @param aHeaderLength        Size of the ICMP header in bytes.
 *  @param aPatternStartOffset  Offset at which the data pattern begins
 *                              (its first value is 0).
 *  @param aType                ICMP message type stored in the header.
 *
 *  @return The new buffer, or NULL if it could not be created.
 *
 *  The sequence number comes from a function-local counter. Because
 *  ICMPv4EchoHeader and ICMPv6EchoHeader are the same type, both IP
 *  versions share one instantiation and hence one counter.
 */
template <typename tType>
static PacketBuffer *MakeICMPDataBuffer(uint16_t aDesiredUserLength, uint16_t aHeaderLength, uint16_t aPatternStartOffset, uint8_t aType)
{
    static uint16_t  lSequenceNumber = 0;
    PacketBuffer *   lBuffer = NULL;

    // To ensure there is enough room for the user data and the ICMP
    // header, include both the user data size and the ICMP header length.
    //
    // Sum in 32 bits and saturate at UINT16_MAX. A plain 16-bit sum would
    // silently wrap for large user lengths, producing a far smaller
    // request than intended (typically one shorter than the pattern
    // offset, and therefore a NULL result). Saturating lets MakeDataBuffer
    // clamp the request to buffer capacity like any other oversized
    // request.
    const uint32_t   lRequestedLength = static_cast<uint32_t>(aDesiredUserLength) + aHeaderLength;
    const uint16_t   lDesiredLength   = (lRequestedLength > UINT16_MAX) ?
                                            static_cast<uint16_t>(UINT16_MAX) :
                                            static_cast<uint16_t>(lRequestedLength);

    lBuffer = MakeDataBuffer(lDesiredLength, aPatternStartOffset);

    if (lBuffer != NULL)
    {
        tType *lHeader = reinterpret_cast<tType *>(lBuffer->Start());

        lHeader->mType           = aType;
        lHeader->mCode           = 0;
        lHeader->mChecksum       = 0;
        lHeader->mID             = static_cast<uint16_t>(rand() & UINT16_MAX);
        lHeader->mSequenceNumber = nlByteOrderSwap16HostToBig(lSequenceNumber++);
    }

    return (lBuffer);
}

/**
 *  Build an ICMPv4 echo request carrying aDesiredUserLength bytes of
 *  patterned user data immediately after the echo header.
 *  Returns NULL on failure.
 */
PacketBuffer *MakeICMPv4DataBuffer(uint16_t aDesiredUserLength)
{
    // The header size doubles as the offset at which the data pattern begins.
    const uint16_t lHeaderSize = static_cast<uint16_t>(sizeof (ICMPv4EchoHeader));

    return MakeICMPDataBuffer<ICMPv4EchoHeader>(aDesiredUserLength,
                                                lHeaderSize,
                                                lHeaderSize,
                                                gICMPv4Types[kICMP_EchoRequestIndex]);
}

/**
 *  Build an ICMPv6 echo request carrying aDesiredUserLength bytes of
 *  patterned user data immediately after the echo header.
 *  Returns NULL on failure.
 */
PacketBuffer *MakeICMPv6DataBuffer(uint16_t aDesiredUserLength)
{
    // The header size doubles as the offset at which the data pattern begins.
    const uint16_t lHeaderSize = static_cast<uint16_t>(sizeof (ICMPv6EchoHeader));

    return MakeICMPDataBuffer<ICMPv6EchoHeader>(aDesiredUserLength,
                                                lHeaderSize,
                                                lHeaderSize,
                                                gICMPv6Types[kICMP_EchoRequestIndex]);
}

/**
 *  Allocate a buffer patterned from its very first byte, starting at the
 *  value aFirstValue. Returns NULL on failure.
 */
PacketBuffer *MakeDataBuffer(uint16_t aDesiredLength, uint8_t aFirstValue)
{
    return MakeDataBuffer(aDesiredLength, static_cast<uint16_t>(0), aFirstValue);
}

/**
 *  Allocate a buffer patterned from its very first byte, starting at the
 *  value 0. Returns NULL on failure.
 */
PacketBuffer *MakeDataBuffer(uint16_t aDesiredLength)
{
    // The offset argument must be typed uint16_t so that the
    // (length, pattern-start-offset) overload is selected; a bare 0 would
    // be ambiguous with the (length, first-value) overload.
    const uint16_t lStartAtFirstByte = 0;

    return MakeDataBuffer(aDesiredLength, lStartAtFirstByte);
}

/**
 *  Account for, and optionally verify, a received packet buffer chain.
 *
 *  @param aBuffer              Head of the received buffer chain.
 *  @param aStats               Transfer statistics; mReceive.mActual is
 *                              advanced on success.
 *  @param aStatsByPacket       If true, count one per call; otherwise
 *                              count the total number of bytes received.
 *  @param aCheckBuffer         If true, verify the incrementing pattern.
 *  @param aPatternStartOffset  Offset, into the chain as a whole, at
 *                              which the pattern begins.
 *  @param aFirstValue          Expected value of the first pattern byte.
 *
 *  @return true if the data passed verification (or verification was not
 *          requested); false otherwise, in which case the stats are left
 *          unchanged.
 *
 *  The pattern is treated as one continuous run across the chain. The
 *  start offset is consumed only once (possibly spanning several short
 *  buffers), and each subsequent buffer is checked from its first byte.
 *  Re-applying the offset to every buffer would leave the leading bytes
 *  of each later buffer unverified.
 */
static bool HandleDataReceived(const PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer, uint16_t aPatternStartOffset, uint8_t aFirstValue)
{
    bool      lStatus = true;
    // 32 bits so that a long chain (e.g. accumulated TCP data) cannot
    // wrap the byte count.
    uint32_t  lTotalDataLength = 0;
    // Portion of the pattern start offset not yet skipped.
    uint16_t  lRemainingOffset = aPatternStartOffset;

    // Walk through each buffer in the packet chain, checking the
    // buffer for the expected pattern, if requested.

    for (const PacketBuffer *lBuffer = aBuffer; lBuffer != NULL; lBuffer = lBuffer->Next())
    {
        const uint16_t lDataLength  = lBuffer->DataLength();
        const uint16_t lStartOffset = (lRemainingOffset < lDataLength) ? lRemainingOffset : lDataLength;

        if (aCheckBuffer)
        {
            const uint8_t *p = lBuffer->Start();

            lStatus = CheckDataBufferPattern(p, lDataLength, lStartOffset, aFirstValue);
            VerifyOrExit(lStatus == true, );
        }

        lTotalDataLength += lDataLength;

        // Only the patterned bytes of this buffer advance the expected
        // value (modulo 256).
        aFirstValue      = static_cast<uint8_t>(aFirstValue + (lDataLength - lStartOffset));
        lRemainingOffset = static_cast<uint16_t>(lRemainingOffset - lStartOffset);
    }

    // If we are accumulating stats by packet rather than by size,
    // then increment by one (1) rather than the total buffer length.

    aStats.mReceive.mActual += ((aStatsByPacket) ? 1 : lTotalDataLength);

exit:
    return (lStatus);
}

/**
 *  Strip the ICMP header from a received buffer, then account for and
 *  optionally verify the remaining user data.
 *
 *  @param aBuffer        Received buffer; its head is consumed by
 *                        aHeaderLength bytes (this mutates the buffer).
 *  @param aHeaderLength  Size of the ICMP header to skip.
 *  @param aStats         Transfer statistics to update.
 *  @param aStatsByPacket Count by packet (true) or by byte (false).
 *  @param aCheckBuffer   Verify the data pattern if true.
 *
 *  @return true if the data passed verification (or none was requested).
 */
static bool HandleICMPDataReceived(PacketBuffer *aBuffer, uint16_t aHeaderLength, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    // The sender (MakeICMPDataBuffer via MakeDataBuffer) starts the
    // pattern right after the header with the value 0. Once the header is
    // consumed, the pattern therefore begins at offset 0 with value 0.
    const uint16_t  lPatternStartOffset = 0;
    const uint8_t   lFirstValue = 0;
    bool            lStatus;

    aBuffer->ConsumeHead(aHeaderLength);

    // Call the overload that takes both the offset and the first value.
    // The previous five-argument call resolved to the (..., uint8_t
    // aFirstValue) overload, silently narrowing the pattern offset into
    // the first value; it worked only because both were 0.
    lStatus = HandleDataReceived(aBuffer, aStats, aStatsByPacket, aCheckBuffer, lPatternStartOffset, lFirstValue);

    return (lStatus);
}

/**
 *  Handle a received ICMPv4 echo buffer. The echo header is stripped from
 *  aBuffer before the payload is counted and optionally verified.
 */
bool HandleICMPv4DataReceived(PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    return HandleICMPDataReceived(aBuffer,
                                  static_cast<uint16_t>(sizeof (ICMPv4EchoHeader)),
                                  aStats,
                                  aStatsByPacket,
                                  aCheckBuffer);
}

/**
 *  Handle a received ICMPv6 echo buffer. The echo header is stripped from
 *  aBuffer before the payload is counted and optionally verified.
 */
bool HandleICMPv6DataReceived(PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    return HandleICMPDataReceived(aBuffer,
                                  static_cast<uint16_t>(sizeof (ICMPv6EchoHeader)),
                                  aStats,
                                  aStatsByPacket,
                                  aCheckBuffer);
}

/**
 *  Count and optionally verify a received chain whose pattern starts at
 *  its very first byte with the value aFirstValue.
 */
bool HandleDataReceived(const PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer, uint8_t aFirstValue)
{
    return HandleDataReceived(aBuffer, aStats, aStatsByPacket, aCheckBuffer,
                              static_cast<uint16_t>(0), aFirstValue);
}


/**
 *  Count and optionally verify a received chain whose pattern starts at
 *  its very first byte with the value 0.
 */
bool HandleDataReceived(const PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    return HandleDataReceived(aBuffer, aStats, aStatsByPacket, aCheckBuffer,
                              static_cast<uint16_t>(0), static_cast<uint8_t>(0));
}

/**
 *  Handle a received UDP payload. UDP data carries no extra header here,
 *  so it is checked as a plain zero-based pattern.
 */
bool HandleUDPDataReceived(const PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    return HandleDataReceived(aBuffer, aStats, aStatsByPacket, aCheckBuffer);
}

/**
 *  Handle received TCP data. It is checked as a plain zero-based pattern,
 *  exactly as UDP data is.
 */
bool HandleTCPDataReceived(const PacketBuffer *aBuffer, TransferStats &aStats, bool aStatsByPacket, bool aCheckBuffer)
{
    return HandleDataReceived(aBuffer, aStats, aStatsByPacket, aCheckBuffer);
}

// Timer Callback Handler

/**
 *  Send-interval timer callback. It reports a timer error via FAIL_ERROR,
 *  marks the send interval as expired, and then drives the next send.
 */
void HandleSendTimerComplete(System::Layer *aSystemLayer, void *aAppState, System::Error aError)
{
    // Neither the system layer nor the application state is needed here.
    (void) aSystemLayer;
    (void) aAppState;

    FAIL_ERROR(aError, "Send timer completed with error");

    // Update the flag before calling DriveSend(), which presumably
    // consults it to decide whether a send may proceed.
    gSendIntervalExpired = true;
    DriveSend();
}

// Raw Endpoint Callback Handlers

/**
 *  Raw endpoint receive callback. It logs the source and destination
 *  addresses and the data length of the first buffer.
 *
 *  aPacketInfo is treated as nullable, consistent with the receive-error
 *  handlers in this file. When it is NULL, both addresses are reported as
 *  "(unknown)" instead of being dereferenced.
 */
void HandleRawMessageReceived(const IPEndPointBasis *aEndPoint, const PacketBuffer *aBuffer, const IPPacketInfo *aPacketInfo)
{
    char  lSourceAddressBuffer[INET6_ADDRSTRLEN];
    char  lDestinationAddressBuffer[INET6_ADDRSTRLEN];

    if (aPacketInfo != NULL)
    {
        aPacketInfo->SrcAddress.ToString(lSourceAddressBuffer, sizeof (lSourceAddressBuffer));
        aPacketInfo->DestAddress.ToString(lDestinationAddressBuffer, sizeof (lDestinationAddressBuffer));
    }
    else
    {
        strcpy(lSourceAddressBuffer, "(unknown)");
        strcpy(lDestinationAddressBuffer, "(unknown)");
    }

    printf("Raw message received from %s to %s (%zu bytes)\n",
           lSourceAddressBuffer,
           lDestinationAddressBuffer,
           static_cast<size_t>(aBuffer->DataLength()));
}

void HandleRawReceiveError(const IPEndPointBasis *aEndPoint, const INET_ERROR &aError, const IPPacketInfo *aPacketInfo)
{
    char     lAddressBuffer[INET6_ADDRSTRLEN];

    if (aPacketInfo != NULL)
    {
        aPacketInfo->SrcAddress.ToString(lAddressBuffer, sizeof (lAddressBuffer));
    }
    else
    {
        strcpy(lAddressBuffer, "(unknown)");
    }

    printf("IP receive error from %s %s\n", lAddressBuffer, ErrorStr(aError));
}

// UDP Endpoint Callback Handlers

/**
 *  UDP endpoint receive callback. It logs the source and destination
 *  address/port pairs and the data length of the first buffer.
 *
 *  aPacketInfo is treated as nullable, consistent with
 *  HandleUDPReceiveError. When it is NULL, the addresses are reported as
 *  "(unknown)" and the ports as 0 instead of being dereferenced.
 */
void HandleUDPMessageReceived(const IPEndPointBasis *aEndPoint, const PacketBuffer *aBuffer, const IPPacketInfo *aPacketInfo)
{
    char      lSourceAddressBuffer[INET6_ADDRSTRLEN];
    char      lDestinationAddressBuffer[INET6_ADDRSTRLEN];
    uint16_t  lSourcePort = 0;
    uint16_t  lDestinationPort = 0;

    if (aPacketInfo != NULL)
    {
        aPacketInfo->SrcAddress.ToString(lSourceAddressBuffer, sizeof (lSourceAddressBuffer));
        aPacketInfo->DestAddress.ToString(lDestinationAddressBuffer, sizeof (lDestinationAddressBuffer));
        lSourcePort      = aPacketInfo->SrcPort;
        lDestinationPort = aPacketInfo->DestPort;
    }
    else
    {
        strcpy(lSourceAddressBuffer, "(unknown)");
        strcpy(lDestinationAddressBuffer, "(unknown)");
    }

    printf("UDP packet received from %s:%u to %s:%u (%zu bytes)\n",
           lSourceAddressBuffer, lSourcePort,
           lDestinationAddressBuffer, lDestinationPort,
           static_cast<size_t>(aBuffer->DataLength()));
}

void HandleUDPReceiveError(const IPEndPointBasis *aEndPoint, const INET_ERROR &aError, const IPPacketInfo *aPacketInfo)
{
    char     lAddressBuffer[INET6_ADDRSTRLEN];
    uint16_t lSourcePort;

    if (aPacketInfo != NULL)
    {
        aPacketInfo->SrcAddress.ToString(lAddressBuffer, sizeof (lAddressBuffer));
        lSourcePort = aPacketInfo->SrcPort;
    }
    else
    {
        strcpy(lAddressBuffer, "(unknown)");
        lSourcePort = 0;
    }

    printf("UDP receive error from %s:%u: %s\n", lAddressBuffer, lSourcePort, ErrorStr(aError));
}

}; // namespace Common
