/* fromwire.c
 *
 * Copyright (c) 2019 Apple Computer, Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * DNS wire-format utility functions.
 *
 * Functions that are neither necessary for very simple DNS packet generation, nor required for parsing
 * a message, e.g. compression, name printing, etc.
 */

#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <stdlib.h>
#include <ctype.h>
#include "srp.h"
#include "dns-msg.h"

#include "mDNSEmbeddedAPI.h"
#include "DNSCommon.h"

#undef LogMsg
#define LogMsg(...)

// We need the compression routines from DNSCommon.c, but we can't link to it because that
// pulls in a _lot_ of stuff we don't want.   The solution?   Define STANDALONE (this is done
// in the Makefile) and include DNSCommon.c.
//
// The only functions that aren't excluded by STANDALONE are FindCompressionPointer and
// putDomainNameAsLabels.

#include "DNSCommon.c"

// Append a domain name to the message being built in towire. The name is written with
// putDomainNameAsLabels (from DNSCommon.c), which may compress it against names already in the message.
//
// The name is first assembled in a scratch buffer from up to three pieces, in this order:
//   - prefix: a name in text form, encoded with dns_name_to_wire(); if prefix is NULL,
//   - labels_prefix: an already-parsed name, copied with dns_name_to_wire_canonical();
//   - suffix: a name in text form, encoded with dns_full_name_to_wire().
// labels_prefix is ignored when prefix is non-NULL.
//
// If towire->error is already set on entry, nothing is written. On failure, towire->error is set to the
// error reported while assembling the name, or to ENOBUFS if the name did not fit. towire->truncated is
// set when the failure was a lack of space. towire->p is left unchanged on failure.
void
dns_concatenate_name_to_wire(dns_towire_state_t *towire, dns_name_t *labels_prefix, const char *prefix, const char *suffix)
{
    dns_wire_t namebuf;
    dns_towire_state_t namewire;
    mDNSu8 *ret;

    // As with the other towire helpers, do nothing once an earlier step has failed.
    if (towire->error) {
        return;
    }

    memset(&namewire, 0, sizeof namewire);
    namewire.message = &namebuf;
    namewire.lim = &namebuf.data[DNS_DATA_SIZE];
    namewire.p = namebuf.data;
    if (prefix != NULL) {
        dns_name_to_wire(NULL, &namewire, prefix);
    } else if (labels_prefix != NULL) {
        size_t bytes_written = dns_name_to_wire_canonical(namewire.p, namewire.lim - namewire.p, labels_prefix);
        // Zero means the labels didn't fit in the scratch buffer; this can never occur with a valid name.
        if (bytes_written == 0) {
            namewire.error = ENOBUFS;
            namewire.truncated = true;
        } else {
            namewire.p += bytes_written;
        }
    }
    if (suffix != NULL) {
        dns_full_name_to_wire(NULL, &namewire, suffix);
    }

    // Always end the scratch name with a zero (root) byte, so that putDomainNameAsLabels never reads bytes
    // that were not written here. For example, when no piece is supplied, namebuf.data would otherwise be
    // entirely uninitialized.
    // NOTE(review): presumably this also terminates a prefix-only name, if dns_name_to_wire omits the
    // root label. When the name already ends in a root label, this extra byte follows it and is presumably
    // never read.
    if (!namewire.error && !namewire.truncated) {
        if (namewire.p < namewire.lim) {
            *namewire.p = 0;
        } else {
            namewire.error = ENOBUFS;
            namewire.truncated = true;
        }
    }

    // Never write a partially-assembled name into the message; report why instead.
    if (namewire.error || namewire.truncated) {
        towire->error = namewire.error ? namewire.error : ENOBUFS;
        if (namewire.truncated) {
            towire->truncated = true;
        }
        return;
    }

    ret = putDomainNameAsLabels((DNSMessage *)towire->message, towire->p, towire->lim, (domainname *)namebuf.data);
    if (ret == NULL) {
        towire->error = ENOBUFS;
        towire->truncated = true;
        return;
    }

    // Shouldn't happen
    if (ret > towire->lim) {
        towire->error = ENOBUFS;
        towire->truncated = true;
    } else {
        towire->p = ret;
    }
}

// Format a parsed name as dotted text into buf, which holds bufmax bytes, and return buf.
//
// Labels are separated by '.'. Printable ASCII bytes are copied as-is. Every other byte is written as a
// backslash followed by its three-digit decimal value ("\DDD"). A zero-length final label that follows
// other labels therefore shows up as a trailing '.'.
//
// If there isn't room in the buffer (shouldn't be the case), as much as fits is copied, stopping at a
// character boundary. The result is always NUL-terminated provided bufmax is at least 1. With bufmax < 1
// there is no room even for the terminator, so buf is returned untouched.
const char *NONNULL
dns_name_print(dns_name_t *NONNULL name, char *buf, int bufmax)
{
    dns_label_t *lp;
    int ix = 0;
    int i;

    if (bufmax < 1) {
        return buf;
    }

    for (lp = name; lp; lp = lp->next) {
        if (ix != 0) {
            if (ix + 2 >= bufmax) {
                break;
            }
            buf[ix++] = '.';
        }
        for (i = 0; i < lp->len; i++) {
            // Work on the unsigned byte value. If char is signed, bytes >= 0x80 would otherwise be negative,
            // producing garbage "\DDD" digits and passing an invalid value to isprint().
            unsigned char byte = (unsigned char)lp->data[i];

            if (byte < 0x80 && isprint(byte)) {
                // One character plus room for the NUL terminator.
                if (ix + 2 >= bufmax) {
                    break;
                }
                buf[ix++] = (char)byte;
            } else {
                // "\DDD" plus room for the NUL terminator.
                if (ix + 5 >= bufmax) {
                    break;
                }
                buf[ix++] = '\\';
                buf[ix++] = (char)('0' + byte / 100);
                buf[ix++] = (char)('0' + (byte / 10) % 10);
                buf[ix++] = (char)('0' + byte % 10);
            }
        }
        // A label that didn't fit completely ends the output.
        if (i != lp->len) {
            break;
        }
    }
    buf[ix] = 0;
    return buf;
}

// Compare two labels of len bytes each and return true if they are equal.
// Pairs of ASCII bytes compare case-insensitively. If either byte of a pair is >= 0x80, the two
// bytes must match exactly.
bool
dns_labels_equal(const char *label1, const char *label2, size_t len)
{
    size_t i;

    for (i = 0; i < len; i++) {
        // Compare byte values, so ctype functions never see a negative char on signed-char platforms.
        unsigned char c1 = (unsigned char)label1[i];
        unsigned char c2 = (unsigned char)label2[i];

        if (c1 < 0x80 && c2 < 0x80) {
            if (tolower(c1) != tolower(c2)) {
                return false;
            }
        } else if (c1 != c2) {
            return false;
        }
    }
    return true;
}

// Return true if two parsed names are equal. The label lists must have the same number of labels,
// and each pair of labels must have the same length and compare equal under dns_labels_equal()
// (ASCII case-insensitive).
bool
dns_names_equal(dns_label_t *NONNULL name1, dns_label_t *NONNULL name2)
{
    dns_label_t *left = name1;
    dns_label_t *right = name2;

    for (;;) {
        if (left->len != right->len) {
            return false;
        }
        // Zero-length labels have no data to compare.
        if (left->len != 0 && !dns_labels_equal(left->data, right->data, left->len)) {
            return false;
        }
        // Once either list runs out, the names match only if both ran out together.
        if (left->next == NULL || right->next == NULL) {
            return left->next == NULL && right->next == NULL;
        }
        left = left->next;
        right = right->next;
    }
}

// Compare a parsed name (a list of labels) with a name written as dotted text; return true if they match.
//
// name2 is split into labels at each '.'. Within a text label, a byte may be written as a backslash
// followed by exactly three decimal digits ("\DDD", value 0-255). A text name that uses a backslash in
// any other way is invalid and compares unequal to everything. ASCII bytes compare case-insensitively,
// as in dns_labels_equal(); all other bytes must match exactly.
//
// Note that "foo.arpa" is not the same as "foo.arpa.": the trailing dot produces an empty final text
// label, which matches only a zero-length last label in name1.
bool
dns_names_equal_text(dns_label_t *NONNULL name1, const char *NONNULL name2)
{
    const char *ndot;
    const char *s;
    size_t tlen = 0;
    size_t i;

    // Find the end of the first text label.
    ndot = strchr(name2, '.');
    if (ndot == NULL) {
        ndot = name2 + strlen(name2);
    }

    // Validate the text label and count how many bytes it decodes to.
    for (s = name2; s < ndot; s++) {
        if (*s == '\\') {
            // Check the remaining length rather than forming s + 4, which could point past the end of name2.
            if (ndot - s < 4 ||
                !isdigit((unsigned char)s[1]) || !isdigit((unsigned char)s[2]) || !isdigit((unsigned char)s[3]) ||
                (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0') > 255) {
                return false;  // An invalid name can't be equal to anything.
            }
            s += 3;
        }
        tlen++;
    }
    if ((size_t)name1->len != tlen) {
        return false;
    }

    // Compare the decoded text label with the parsed label, one byte at a time.
    for (s = name2, i = 0; s < ndot; s++, i++) {
        unsigned int text_byte;
        // Use the unsigned byte value so that bytes >= 0x80 compare correctly when char is signed.
        unsigned int label_byte = (unsigned char)name1->data[i];

        if (*s == '\\') { // already validated above
            text_byte = (unsigned int)((s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0'));
            s += 3;
        } else {
            text_byte = (unsigned char)*s;
        }
        if (text_byte < 0x80 && label_byte < 0x80) {
            if (tolower((int)text_byte) != tolower((int)label_byte)) {
                return false;
            }
        } else if (text_byte != label_byte) {
            return false;
        }
    }

    // Move on to the next label on both sides; the names match only if both end together.
    if (name1->next != NULL && *ndot == '.') {
        return dns_names_equal_text(name1->next, ndot + 1);
    }
    if (name1->next == NULL && *ndot == 0) {
        return true;
    }
    return false;
}

// Find the length of a name in uncompressed wire format.
// This is in fromwire because we use it for validating signatures, and don't need it for
// sending.
//
// Each label in the list contributes one length byte plus its data, so a zero-length label
// (such as a final root label) counts as one byte. ret is the running total so far.
// name is NULLABLE because the recursion passes name->next, which is NULL at the end of the list.
static size_t
dns_name_wire_length_in(dns_label_t *NULLABLE name, size_t ret)
{
    // End of the label list: nothing more to count.
    if (name == NULL) {
        return ret;
    }
    return dns_name_wire_length_in(name->next, ret + name->len + 1);
}

// Return the number of bytes name occupies in uncompressed wire format.
size_t
dns_name_wire_length(dns_label_t *NONNULL name)
{
    return dns_name_wire_length_in(name, 0);
}

// Copy a name we've parsed from a message out in canonical wire format so that we can
// use it to verify a signature.   As above, not actually needed for copying to a message
// we're going to send, since in that case we want to try to compress.
//
// Each label is written uncompressed, as a length byte followed by the label data. Letter case is
// copied as-is.
// NOTE(review): RFC 4034 section 6.2 canonical form also lowercases ASCII letters. Confirm that callers
// (signature verification, dns_concatenate_name_to_wire) rely on case being preserved here.
//
// Writes at most max bytes to buf. Returns ret plus the number of bytes written, or 0 if the name
// does not fit; in that case buf may already hold some of the leading labels.
// name is NULLABLE because the recursion passes name->next, which is NULL at the end of the list.
static size_t
dns_name_to_wire_canonical_in(uint8_t *NONNULL buf, size_t max, size_t ret, dns_label_t *NULLABLE name)
{
    // End of the label list: everything has been written.
    if (name == NULL) {
        return ret;
    }
    if (max < (size_t)name->len + 1) {
        return 0;
    }
    *buf = name->len;
    memcpy(buf + 1, name->data, name->len);
    return dns_name_to_wire_canonical_in(buf + name->len + 1,
                                         max - name->len - 1, ret + name->len + 1, name->next);
}

// Write name to buf (at most max bytes) in uncompressed wire format.
// Returns the number of bytes written, or 0 if it does not fit.
size_t
dns_name_to_wire_canonical(uint8_t *NONNULL buf, size_t max, dns_label_t *NONNULL name)
{
    return dns_name_to_wire_canonical_in(buf, max, 0, name);
}
    
// Local Variables:
// mode: C
// tab-width: 4
// c-file-style: "bsd"
// c-basic-offset: 4
// fill-column: 108
// indent-tabs-mode: nil
// End:
