/*
 * policy.h
 *
 * Common functions for handling TOMOYO Linux's domain policy.
 *
 * Copyright (C) 2005-2006  NTT DATA CORPORATION
 *
 * Version: 1.3.1   2006/12/08
 *
 */
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/un.h>

/*
 * OutOfMemory - fatal allocation-failure handler.
 *
 * Prints "Out of memory. Aborted." to stderr and terminates the process
 * with exit status 1. Never returns.
 */
static void OutOfMemory(void) {
	fputs("Out of memory. Aborted.\n", stderr);
	exit(1);
}

/*
 * partial_name_hash - fold one character @c into the running hash @prevhash.
 *
 * Copied from kernel source; the arithmetic must stay exactly as is.
 */
static inline unsigned long partial_name_hash(unsigned long c, unsigned long prevhash) {
	const unsigned long mixed = prevhash + (c << 4) + (c >> 4);
	return mixed * 11;
}

/*
 * full_name_hash - hash the first @len bytes of @name.
 *
 * Copied from kernel source: feeds each byte to partial_name_hash(),
 * starting from 0, and truncates the result to unsigned int.
 */
static inline unsigned int full_name_hash(const unsigned char *name, unsigned int len) {
	unsigned long hash = 0;
	unsigned int i;
	for (i = 0; i < len; i++) {
		hash = partial_name_hash(name[i], hash);
	}
	return (unsigned int) hash;
}

#define PAGE_SIZE  4096

/*
 * alloc_element - carve @size zero-filled bytes out of a private page pool.
 *
 * Memory comes from PAGE_SIZE chunks obtained with malloc() and is handed
 * out sequentially; this function never releases anything. When the
 * current chunk cannot hold @size more bytes, its unused tail is abandoned
 * and a fresh zero-filled chunk is started.
 *
 * Returns NULL when @size is 0 or larger than PAGE_SIZE. Exits through
 * OutOfMemory() if malloc() fails, or if a region about to be handed out
 * is not all zero (presumably a sign that a caller wrote past its
 * allocation).
 *
 * NOTE(review): sizes are not rounded up for alignment; callers are
 * expected to request sizes that keep later elements suitably aligned.
 */
static char *alloc_element(const unsigned int size) {
	static char *page = NULL;
	/* Start out "full" so that the very first request allocates a chunk. */
	static unsigned int page_used = PAGE_SIZE;
	char *region;
	unsigned int k;
	if (size > PAGE_SIZE) return NULL;
	if (page_used + size > PAGE_SIZE) {
		/* Not enough room left: switch to a brand-new zeroed chunk. */
		page = malloc(PAGE_SIZE);
		if (page == NULL) OutOfMemory();
		memset(page, 0, PAGE_SIZE);
		page_used = size;
		return page;
	}
	if (size == 0) return NULL;
	region = page + page_used;
	page_used += size;
	/* Sanity check: bytes that were never handed out must still be zero. */
	for (k = 0; k < size; k++) {
		if (region[k] != '\0') OutOfMemory();
	}
	return region;
}

/* Number of hash buckets in SaveName()'s table of interned names. */
#define MAX_HASH 256

/*
 * One interned string in SaveName()'s hash table. Entries that fall into
 * the same bucket are chained through next.
 */
typedef struct name_entry {
	struct name_entry *next; /* Pointer to next record. NULL if none.             */
	unsigned int hash;       /* full_name_hash() of name, XORed with strlen(name) + 1. */
	const char *name;        /* Text form of filename and domainname. Never NULL. */
} NAME_ENTRY;

/*
 * Free-space record for one of the PAGE_SIZE chunks that SaveName() copies
 * names into. SaveName() keeps these in a singly linked list and unlinks a
 * record once its len drops to 0.
 */
typedef struct free_memory_block_list {
	struct free_memory_block_list *next; /* Pointer to next record. NULL if none. */
	char *ptr;                           /* Pointer to the first unused byte of the chunk. */
	int len;                             /* Number of unused bytes remaining at ptr. */
} FREE_MEMORY_BLOCK_LIST;

static const char *SaveName(const char *name) {
	static FREE_MEMORY_BLOCK_LIST fmb_list = { NULL, NULL, 0 };
	static NAME_ENTRY name_list[MAX_HASH]; /* The list of names. */
	NAME_ENTRY *ptr, *prev = NULL;
	unsigned int hash;
	FREE_MEMORY_BLOCK_LIST *fmb = &fmb_list;
	int len;
	static int first_call = 1;
	if (!name) return NULL;
	len = strlen(name) + 1;
	if (len > PAGE_SIZE) {
		fprintf(stderr, "ERROR: Name too long for SaveName().\n");
		return NULL;
	}
	hash = full_name_hash((const unsigned char *) name, len - 1);
	if (first_call) {
		int i;
		first_call = 0;
		memset(&name_list, 0, sizeof(name_list));
		for (i = 0; i < MAX_HASH; i++) name_list[i].name = "/";
	}
	ptr = &name_list[hash % MAX_HASH];
	hash ^= len; /* The hash % MAX_HASH are always same for ptr->hash, so embed length into the hash value. */
	while (ptr) {
		if (hash == ptr->hash && strcmp(name, ptr->name) == 0) goto out;
		prev = ptr; ptr = ptr->next;
	}
	while (len > fmb->len) {
		if (fmb->next) {
			fmb = fmb->next;
		} else {
			char *cp;
			if ((cp = (char *) malloc(PAGE_SIZE)) == NULL || (fmb->next = (FREE_MEMORY_BLOCK_LIST *) alloc_element(sizeof(FREE_MEMORY_BLOCK_LIST))) == NULL) OutOfMemory();
			memset(cp, 0, PAGE_SIZE);
			fmb = fmb->next;
			fmb->ptr = cp;
			fmb->len = PAGE_SIZE;
		}
	}
	if ((ptr = (NAME_ENTRY *) alloc_element(sizeof(NAME_ENTRY))) == NULL) OutOfMemory();
	memset(ptr, 0, sizeof(NAME_ENTRY));
	ptr->hash = hash;
	ptr->name = fmb->ptr;
	memmove(fmb->ptr, name, len);
	fmb->ptr += len;
	fmb->len -= len;
	prev->next = ptr; /* prev != NULL because name_list is not empty. */
	if (fmb->len == 0) {
		FREE_MEMORY_BLOCK_LIST *ptr = &fmb_list;
		while (ptr->next != fmb) ptr = ptr->next; ptr->next = fmb->next;
	}
 out:
	return ptr ? (const char *) ptr->name : NULL;
}

#define ROOT_NAME "<kernel>"

/*
 * IsCorrectDomain - check that @domainname is a syntactically valid domainname.
 *
 * Accepted form: ROOT_NAME, optionally followed by any number of
 * " /component" parts. Inside a component:
 *   - bytes 0x21..0x7E other than '\\' stand for themselves;
 *   - "\\\\" stands for one backslash;
 *   - "\\ooo" (first digit 0-3, others 0-7) is an octal escape, accepted
 *     only for a non-zero value <= 0x20 or >= 0x7F, i.e. a byte that may
 *     not appear literally.
 * A space ends the current component.
 *
 * Returns 1 if valid, 0 otherwise (including when @domainname is NULL).
 */
static int IsCorrectDomain(const unsigned char *domainname) {
	const size_t root_len = strlen(ROOT_NAME);
	const unsigned char *p;
	if (domainname == NULL) return 0;
	if (strncmp((const char *) domainname, ROOT_NAME, root_len) != 0) return 0;
	p = domainname + root_len;
	while (*p != '\0') {
		/* Every component is introduced by " /". */
		if (*p++ != ' ') return 0;
		if (*p++ != '/') return 0;
		while (*p != '\0' && *p != ' ') {
			const unsigned char ch = *p++;
			unsigned char hi, mid, lo, value;
			if (ch != '\\') {
				if (ch < ' ' || ch >= 127) return 0;
				continue;
			}
			hi = *p++;
			if (hi == '\\') continue;
			if (hi < '0' || hi > '3') return 0;
			mid = *p++;
			if (mid < '0' || mid > '7') return 0;
			lo = *p++;
			if (lo < '0' || lo > '7') return 0;
			value = (unsigned char) (((hi - '0') << 6) | ((mid - '0') << 3) | (lo - '0'));
			/* \000 is never allowed; printable characters must not be escaped. */
			if (value == 0 || (value > ' ' && value < 127)) return 0;
		}
	}
	return 1;
}

/*
 * One domain in a policy: its name plus the policy lines that belong to it.
 */
typedef struct domain_info {
	const char *domainname;  /* Domain name; the pointer comes from SaveName(). */
	const char **string_ptr; /* realloc()ed array of SaveName()d entry lines; NULL until the first entry is added. */
	int string_count;        /* Number of valid elements in string_ptr. */
} DOMAIN_INFO;

/*
 * Two independent domain lists. Functions in this file select one with
 * their "type" argument (0 or 1). NOTE(review): what each slot represents
 * is decided by the including program -- confirm there.
 */
static DOMAIN_INFO *domain_list[2] = { NULL, NULL };
static int domain_list_count[2] = { 0, 0 };

/*
 * AddStringEntry - append an entry line to domain_list[type][index].
 *
 * @entry is interned with SaveName(); if the domain already holds the same
 * interned pointer nothing is added. The entry array grows by one element
 * per addition.
 *
 * Returns 0 on success (including "already present"), or -EINVAL for an
 * out-of-range @index (with a message on stderr) or a NULL/empty @entry.
 * Exits via OutOfMemory() if SaveName() returns NULL or realloc() fails.
 */
static int AddStringEntry(const char *entry, const int index, const int type) {
	const char *saved;
	const char **list;
	int count, pos;
	if (index < 0 || index >= domain_list_count[type]) {
		fprintf(stderr, "AddStringEntry: ERROR: domain is out of range.\n");
		return -EINVAL;
	}
	if (entry == NULL || *entry == '\0') return -EINVAL;
	saved = SaveName(entry);
	if (saved == NULL) OutOfMemory();
	list = domain_list[type][index].string_ptr;
	count = domain_list[type][index].string_count;
	/* Interned strings: pointer equality is string equality. */
	pos = 0;
	while (pos < count) {
		if (list[pos] == saved) return 0;
		pos++;
	}
	list = (const char **) realloc(list, (count + 1) * sizeof(*list));
	if (list == NULL) OutOfMemory();
	list[count] = saved;
	domain_list[type][index].string_ptr = list;
	domain_list[type][index].string_count = count + 1;
	return 0;
}

/*
 * DelStringEntry - remove one entry line from domain_list[type][index].
 *
 * @entry is interned with SaveName() and matched by pointer. The remaining
 * entries keep their order; the array is not shrunk.
 *
 * Returns 0 when removed, -ENOENT when the entry is not present, or
 * -EINVAL for an out-of-range @index (with a message on stderr) or a
 * NULL/empty @entry. Exits via OutOfMemory() if SaveName() returns NULL.
 */
static int DelStringEntry(const char *entry, const int index, const int type) {
	const char *saved;
	const char **list;
	int count, pos;
	if (index < 0 || index >= domain_list_count[type]) {
		fprintf(stderr, "DelStringEntry: ERROR: domain is out of range.\n");
		return -EINVAL;
	}
	if (entry == NULL || *entry == '\0') return -EINVAL;
	saved = SaveName(entry);
	if (saved == NULL) OutOfMemory();
	list = domain_list[type][index].string_ptr;
	count = domain_list[type][index].string_count;
	/* Interned strings: pointer equality is string equality. */
	for (pos = 0; pos < count; pos++) {
		if (list[pos] == saved) break;
	}
	if (pos == count) return -ENOENT;
	/* Close the gap, preserving the order of the remaining entries. */
	memmove(&list[pos], &list[pos + 1], (count - pos - 1) * sizeof(*list));
	domain_list[type][index].string_count = count - 1;
	return 0;
}

/*
 * FindDomain - look up a domain by exact name.
 *
 * Linear search of domain_list[type]. Returns the index of the first domain
 * whose name compares equal to @domainname with strcmp(), or EOF if none.
 */
static int FindDomain(const char *domainname, const int type) {
	const int count = domain_list_count[type];
	int idx = 0;
	while (idx < count) {
		if (!strcmp(domainname, domain_list[type][idx].domainname)) return idx;
		idx++;
	}
	return EOF;
}

static int FindOrAssignNewDomain(const char *domainname, const int type) {
	const char *saved_domainname;
	int index;
	if ((index = FindDomain(domainname, type)) == EOF) {
		if (IsCorrectDomain(domainname)) {
			if ((domain_list[type] = (DOMAIN_INFO *) realloc(domain_list[type], (domain_list_count[type] + 1) * sizeof(DOMAIN_INFO))) == NULL) OutOfMemory();
			memset(&domain_list[type][domain_list_count[type]], 0, sizeof(DOMAIN_INFO));
			if ((saved_domainname = SaveName(domainname)) == NULL) OutOfMemory();
			domain_list[type][domain_list_count[type]].domainname = saved_domainname;
			index = domain_list_count[type]++;
		} else {
			fprintf(stderr, "FindOrAssignNewDomain: Invalid domainname '%s'\n", domainname);
		}
	}
	return index;
}

/*
 * DeleteDomain - remove domain_list[type][index] and free its entry array.
 *
 * Does nothing unless 0 < index < domain_list_count[type]; index 0 is never
 * removed. NOTE(review): presumably this protects the ROOT_NAME domain,
 * which SortPolicy() places first when present -- confirm. The domain name
 * itself came from SaveName() and is not freed here.
 */
static void DeleteDomain(const int index, const int type) {
	const int count = domain_list_count[type];
	if (index <= 0 || index >= count) return;
	free(domain_list[type][index].string_ptr);
	/* Shift the following domains down by one slot, keeping their order. */
	memmove(&domain_list[type][index], &domain_list[type][index + 1], (count - index - 1) * sizeof(domain_list[type][0]));
	domain_list_count[type] = count - 1;
}

/*
 * IsDomainDef - tell whether a policy line starts a domain definition.
 *
 * Skips leading bytes that are blanks, control characters or >= 0x7F, then
 * returns 1 if the rest begins with ROOT_NAME, else 0.
 */
static int IsDomainDef(const unsigned char *buffer) {
	const unsigned char *cp = buffer;
	for (; *cp; cp++) {
		if (*cp > 32 && *cp < 127) break;
	}
	return !strncmp((const char *) cp, ROOT_NAME, strlen(ROOT_NAME));
}

/*
 * NormalizeLine - collapse separators in place.
 *
 * Any byte <= 0x20 (blank or control) or >= 0x7F counts as a separator.
 * Leading and trailing separator runs are dropped and every inner run is
 * replaced by a single ' '. The result is never longer than the input, so
 * it is written back into @buffer.
 */
static void NormalizeLine(unsigned char *buffer) {
	const unsigned char *src = buffer;
	unsigned char *dst = buffer;
	int pending_space = 0;
	for (; *src; src++) {
		const unsigned char ch = *src;
		if (ch <= 32 || ch >= 127) {
			/* Only remember the gap once a word has been emitted. */
			if (dst != buffer) pending_space = 1;
			continue;
		}
		if (pending_space) {
			*dst++ = ' ';
			pending_space = 0;
		}
		*dst++ = ch;
	}
	*dst = '\0';
}

static void SortPolicy(const int type) {
	int i, j, k;
	for (i = 0; i < domain_list_count[type]; i++) {
		for (j = i + 1; j < domain_list_count[type]; j++) {
			if (strcmp(domain_list[type][i].domainname, domain_list[type][j].domainname) > 0) {
				DOMAIN_INFO tmp = domain_list[type][i]; domain_list[type][i] = domain_list[type][j]; domain_list[type][j] = tmp;
			}
		}
	}
	for (i = 0; i < domain_list_count[type]; i++) {
		const char **string_ptr = domain_list[type][i].string_ptr;
		const int string_count = domain_list[type][i].string_count;
		for (j = 0; j < string_count; j++) {
			for (k = j + 1; k < string_count; k++) {
				const char *a = string_ptr[j];
				const char *b = string_ptr[k];
				if (*a && *b && strcmp(a + 1, b + 1) > 0) {
					string_ptr[j] = b; string_ptr[k] = a;
				}
			}
		}
	}
}

/* Size of the line buffer below; ReadDomainPolicy() stops at the first input line that does not fit. */
#define MAXBUFSIZE  8192

/* File-scope line buffer; ReadDomainPolicy() reads each input line into it. */
static char buffer[MAXBUFSIZE];

/*
 * ReadDomainPolicy - load a domain policy into domain_list[type].
 *
 * Reads @filename, or standard input when @filename is NULL, line by line
 * into buffer and normalizes each line with NormalizeLine(). A line that
 * starts with ROOT_NAME selects (creating if needed) the current domain via
 * FindOrAssignNewDomain(); any other non-empty line is added to the current
 * domain with AddStringEntry(). Lines before the first valid domain line,
 * or after an invalid one, are ignored. The list is then sorted with
 * SortPolicy().
 *
 * A line too long for buffer stops reading (rather than being misparsed as
 * several lines) and is reported on stderr; read errors are reported too.
 * If the file cannot be opened, a message is printed and nothing is loaded.
 */
static void ReadDomainPolicy(const char *filename, const int type) {
	const char *source = filename ? filename : "standard input";
	FILE *fp = stdin;
	int index = EOF; /* Current domain; EOF until a valid domain line is seen. */
	if (filename) {
		if ((fp = fopen(filename, "r")) == NULL) {
			fprintf(stderr, "Can't open %s\n", filename);
			return;
		}
	}
	while (memset(buffer, 0, sizeof(buffer)), fgets(buffer, sizeof(buffer) - 1, fp) != NULL) {
		char *cp = strchr(buffer, '\n');
		if (cp) {
			*cp = '\0';
		} else if (!feof(fp)) {
			/*
			 * No newline and not at end of file: the line did not fit.
			 * Stop instead of treating its tail as a new line, but make
			 * it visible that the policy was only partially read.
			 */
			fprintf(stderr, "ReadDomainPolicy: ERROR: Line too long in %s. Stopped reading.\n", source);
			break;
		}
		NormalizeLine((unsigned char *) buffer);
		if (IsDomainDef((const unsigned char *) buffer)) {
			index = FindOrAssignNewDomain(buffer, type);
		} else if (index >= 0 && buffer[0]) {
			AddStringEntry(buffer, index, type);
		}
	}
	if (ferror(fp)) fprintf(stderr, "ReadDomainPolicy: ERROR: Read error on %s.\n", source);
	if (fp != stdin) fclose(fp);
	SortPolicy(type);
}

/*
 * WritePolicyBytes - write all @len bytes of @buf to @fd.
 *
 * Retries after partial writes and after EINTR. Returns 0 on success or a
 * negative errno value (-EIO if write() unexpectedly reports 0 bytes).
 */
static int WritePolicyBytes(const int fd, const char *buf, size_t len) {
	while (len > 0) {
		const ssize_t written = write(fd, buf, len);
		if (written < 0) {
			if (errno == EINTR) continue;
			return -errno;
		}
		if (written == 0) return -EIO;
		buf += written;
		len -= (size_t) written;
	}
	return 0;
}

/*
 * WriteDomainPolicy - write domain_list[type] to @fd in policy text form.
 *
 * For each domain: the domain name followed by an empty line, then one
 * entry per line, then an empty line.
 *
 * Returns 0 on success, or the negative errno value of the first failed
 * write; output stops at that point.
 */
static int WriteDomainPolicy(const int fd, const int type) {
	int i, j, error;
	for (i = 0; i < domain_list_count[type]; i++) {
		/* Take the name and its length from the same list that is being written. */
		const char *domainname = domain_list[type][i].domainname;
		const char **string_ptr = domain_list[type][i].string_ptr;
		const int string_count = domain_list[type][i].string_count;
		if ((error = WritePolicyBytes(fd, domainname, strlen(domainname))) != 0) return error;
		if ((error = WritePolicyBytes(fd, "\n\n", 2)) != 0) return error;
		for (j = 0; j < string_count; j++) {
			if ((error = WritePolicyBytes(fd, string_ptr[j], strlen(string_ptr[j]))) != 0) return error;
			if ((error = WritePolicyBytes(fd, "\n", 1)) != 0) return error;
		}
		if ((error = WritePolicyBytes(fd, "\n", 1)) != 0) return error;
	}
	return 0;
}
