/*-
 * Copyright (c) 2000, Shunsuke Akiyama <akiyama@FreeBSD.org>.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 *	$Id: sockspipe.c,v 1.5 2000/07/20 14:07:36 akiyama Exp $
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include <string.h>
#include <signal.h>
#include <errno.h>
#include <sys/time.h>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <pthread.h>

#include "sockspipe.h"

/* getopt(3) state consulted by parse_options() and main() */
extern char *optarg;
extern int optind;

char *prog = "sockspipe";	/* prefix for every diagnostic message */
int f_debug;			/* -d: enables d_printf() output */
int f_logging;			/* -l: diagnostics go to the log file */
int f_verbose;			/* -v: enables v_printf() output */
char *logfile;			/* path given with -l; NULL when -l is absent */
FILE *lfout;			/* diagnostic stream: stderr or the -l log file */
int io_term;			/* non-zero ends relay_io() loops; must be initialized by the caller */
pthread_t threads[2];		/* [0] stdin->network, [1] network->stdout */
int retval[2];			/* relay_io() status of each thread, indexed like threads[] */
char iobuffer[2][IOSIZE];	/* per-thread relay buffers, indexed like threads[] */

/*
 * Debug trace output: format fmt and its arguments onto lfout, but only
 * when debugging (-d, f_debug) is enabled.  Otherwise does nothing.
 */
void
d_printf(char *fmt, ...)
{
    va_list args;

    if (!f_debug)
	return;

    va_start(args, fmt);
    vfprintf(lfout, fmt, args);
    va_end(args);
}

/*
 * Verbose output: format fmt and its arguments onto lfout whenever
 * verbose (-v) or debug (-d) mode is on; silent otherwise.
 */
void
v_printf(char *fmt, ...)
{
    va_list args;
    int enabled = f_verbose || f_debug;

    if (!enabled)
	return;

    va_start(args, fmt);
    vfprintf(lfout, fmt, args);
    va_end(args);
}

/*
 * Copy data from descriptor src to descriptor dst until src reaches
 * end-of-file or the global io_term flag becomes non-zero.
 *
 * src   - descriptor to read from (waited on with select()).
 * dst   - descriptor that receives every byte read from src.
 * ibuf  - caller-supplied scratch buffer of at least bsize bytes.
 * bsize - maximum number of bytes read per iteration.
 *
 * Returns 0 on end-of-file (or when io_term is set) and -1 on an
 * unrecoverable select/read/write error, leaving errno describing it.
 * EINTR and EAGAIN are treated as transient and the call is retried.
 */
int
relay_io(src, dst, ibuf, bsize)
    int src;
    int dst;
    char *ibuf;
    size_t bsize;
{
    fd_set fdset;
    ssize_t len, n;
    char *buf;
    int ready;

    while (!io_term) {
	FD_ZERO(&fdset);
	FD_SET(src, &fdset);

	/* block (no timeout) until src is readable */
	ready = select(src + 1, &fdset, NULL, NULL, NULL);
	if (ready < 0) {
	    if (errno == EAGAIN || errno == EINTR)
		continue;
	    return -1;
	}
	if (ready == 0 || !FD_ISSET(src, &fdset))
	    continue;

	/* Read data from source descriptor */
	len = read(src, ibuf, bsize);
	if (len < 0) {
	    if (errno == EAGAIN || errno == EINTR)
		continue;
	    return -1;
	}
	if (len == 0)
	    break;		/* connection closed */

	/*
	 * Write everything that was read, coping with short writes.  On a
	 * transient error retry the same chunk: buf and len must only be
	 * advanced by a successful (non-negative) write count.
	 */
	buf = ibuf;
	while (len > 0) {
	    n = write(dst, buf, (size_t)len);
	    if (n < 0) {
		if (errno == EAGAIN || errno == EINTR)
		    continue;
		return -1;
	    }
	    buf += n;
	    len -= n;
	}
    }

    return 0;
}

/*
 * Thread #1: forward standard input to the network socket.
 *
 * arg points to the socket descriptor.  The relay_io() status is kept
 * in retval[0] and handed back as the thread's exit value.
 */
void *
relay_in2net(arg)
    void *arg;
{
    int *status = &retval[0];
    int net = *(int *)arg;

    d_printf("%s(#1): stdin to network redirecting thread started\n", prog);

    *status = relay_io(STDIN_FILENO, net, iobuffer[0], IOSIZE);
    if (*status < 0)
	fprintf(lfout, "%s(#1): %s\n", prog, strerror(errno));

    d_printf("%s: stdin to network redirecting thread (#1) stopped\n", prog);

    pthread_exit(status);
    return status;
}

/*
 * Thread #2: copy data arriving on the network socket to standard output.
 *
 * arg points to the socket descriptor.  The relay_io() status is stored
 * in retval[1] and also returned as the thread's exit value.
 */
void *
relay_net2out(arg)
    void *arg;
{
    int sd;

    d_printf("%s(#2): network to stdout redirecting thread started\n", prog);

    sd = *((int *)arg);
    retval[1] = relay_io(sd, 1, iobuffer[1], IOSIZE);

    if (retval[1] < 0) {
      /* tag the error with this thread's own number (#2) */
      fprintf(lfout, "%s(#2): %s\n", prog, strerror(errno));
    }
    d_printf("%s(#2): network to stdout redirecting thread stopped\n", prog);

    pthread_exit(&retval[1]);
    return &retval[1];
}

/*
 * Reset HUP, INT, QUIT, PIPE and TERM to their default dispositions,
 * so any handler or ignore setting inherited from the parent is dropped.
 */
void
set_signals()
{
    static const int defaults[] = { SIGHUP, SIGINT, SIGQUIT, SIGPIPE, SIGTERM };
    size_t i;

    for (i = 0; i < sizeof(defaults) / sizeof(defaults[0]); i++)
	signal(defaults[i], SIG_DFL);
}

/*
 * Create a TCP socket and connect it to hostname:portnum.
 *
 * sin      - receives the resolved IPv4 peer address (zeroed first).
 * hostname - host name or dotted-quad IPv4 address.
 * portnum  - decimal port number (1-65535) or a TCP service name that
 *            is looked up with getservbyname().
 *
 * Returns the connected socket descriptor.  Every failure is reported
 * on lfout and terminates the process with exit status 1.
 */
int
create_connection(sin, hostname, portnum)
    struct sockaddr_in *sin;
    char *hostname;
    char *portnum;
{
    int sd;
    struct hostent *host;
    struct servent *sp;
    in_addr_t addr;
    in_port_t port;
    long num;
    char *end;

    sd = socket(AF_INET, SOCK_STREAM, 0);
    if (sd < 0) {
	fprintf(lfout, "%s: %s\n", prog, strerror(errno));
	exit(1);
    }

    memset(sin, 0, sizeof(*sin));

    host = gethostbyname(hostname);
    if (host) {
	/*
	 * Only an IPv4 address fits in sin_addr; copying a longer
	 * h_length (e.g. an AF_INET6 result) would overflow it.
	 */
	if (host->h_addrtype != AF_INET ||
	    (size_t)host->h_length != sizeof(sin->sin_addr)) {
	    fprintf(lfout, "%s: %s: %s\n", prog, hostname,
		    "not an IPv4 address");
	    exit(1);
	}
	sin->sin_family = AF_INET;
	memcpy(&sin->sin_addr, host->h_addr, sizeof(sin->sin_addr));
    } else {
	/*
	 * Keep inet_addr()'s result in in_addr_t and compare with
	 * INADDR_NONE: widened into an unsigned long it never equals
	 * (unsigned long)-1 on LP64, so a failure would be mistaken for
	 * the address 255.255.255.255.
	 */
	addr = inet_addr(hostname);
	if (addr != INADDR_NONE) {
	    sin->sin_addr.s_addr = addr;
	    sin->sin_family = AF_INET;
	} else {
	    fprintf(lfout, "%s: %s: %s\n", prog, hostname, hstrerror(h_errno));
	    exit(1);
	}
    }

    /* a fully numeric argument is a port number, anything else a service */
    errno = 0;
    num = strtol(portnum, &end, 10);
    if (end != portnum && *end == '\0') {
	if (errno != 0 || num < 1 || num > 65535) {
	    fprintf(lfout, "%s: %s: %s\n", prog, portnum,
		    "port number out of range");
	    exit(1);
	}
	port = htons((in_port_t)num);
    } else {
	/* getservbyname() does not set h_errno, so don't report it */
	sp = getservbyname(portnum, "tcp");
	if (sp == NULL) {
	    fprintf(lfout, "%s: %s: %s\n", prog, portnum, "unknown service");
	    exit(1);
	}
	port = (in_port_t)sp->s_port;	/* already in network byte order */
    }
    sin->sin_port = port;

    v_printf("%s: connecting to %s port %d\n",
	     prog, hostname, (int)ntohs(port));

    if (connect(sd, (struct sockaddr *)sin, sizeof (*sin)) < 0) {
	fprintf(lfout, "%s: %s\n", prog, strerror(errno));
	exit(1);
    }

    return sd;
}

void
usage()
{
    fprintf(lfout, "%s: socks proxy for OpenSSH, version %s\n", prog, VERSION);
    fprintf(lfout, "usage: %s [-l logfile] [-v] host port\n", prog);
}

/*
 * Parse command-line options into the global option variables.
 *
 *   -d          enable debug tracing (f_debug)
 *   -l logfile  write diagnostics to logfile (opened "w+", i.e. truncated)
 *               instead of stderr (f_logging, logfile, lfout)
 *   -v          enable verbose messages (f_verbose)
 *
 * Any other option prints usage and exits with status 1, as does a
 * failure to open the log file.  On return, optind indexes the first
 * non-option argument.
 */
void
parse_options(argc, argv)
    int argc;
    char **argv;
{
    int ch;

    /* initialize option variables */
    f_debug = 0;
    f_logging = 0;
    f_verbose = 0;
    lfout = stderr;

    /* parse command line options */
    while ((ch = getopt(argc, argv, "dl:v?")) != -1) {
	switch (ch) {
	case 'd' :
	    f_debug = 1;
	    break;
	case 'l' :
	    /* a repeated -l replaces the earlier log; don't leak its FILE */
	    if (f_logging && lfout != NULL && lfout != stderr)
		fclose(lfout);
	    f_logging = 1;
	    logfile = optarg;
	    lfout = fopen(logfile, "w+");
	    if (lfout == NULL) {
		perror(prog);
		exit(1);
	    }
	    break;
	case 'v' :
	    f_verbose = 1;
	    break;
	case '?' :
	default :
	    usage();
	    exit(1);
	}
    }

    d_printf("%s: debug %s\n", prog, (f_debug ? "on" : "off"));
    /* logfile stays NULL without -l; passing NULL to %s is undefined */
    d_printf("%s: logging %s (%s)\n",
	     prog, (f_logging ? "on" : "off"), (logfile ? logfile : "none"));
    d_printf("%s: verbose %s\n", prog, (f_verbose ? "on" : "off"));
}

/*
 * Entry point: sockspipe [options] host port
 *
 * Parses options, connects to host:port, then runs two relay threads
 * (stdin -> socket and socket -> stdout) and waits for both to finish
 * before shutting the connection down and closing the log file.
 */
int
main(argc, argv)
    int argc;
    char **argv;
{
    char *host;
    char *port;
    int sd;
    int err;
    struct sockaddr_in sin;

    io_term = 0;

    /* option and arguments handling */
    parse_options(argc, argv);

    argc -= optind;
    argv += optind;
    if (argc != 2) {
	usage();
	exit(1);
    }

    host = argv[0];
    port = argv[1];

    /* signal handling */
    set_signals();

    /* create socket and connect to */
    sd = create_connection(&sin, host, port);

    /*
     * Create threads.  pthread_create() reports failure by returning a
     * positive error number (it does not return -1 or set errno), so
     * test for non-zero and describe the returned code.
     */
    if ((err = pthread_create(&threads[0], NULL, relay_in2net, &sd)) != 0) {
	fprintf(lfout, "%s: %s\n", prog, strerror(err));
	exit(1);
    }
    if ((err = pthread_create(&threads[1], NULL, relay_net2out, &sd)) != 0) {
	fprintf(lfout, "%s: %s\n", prog, strerror(err));
	exit(1);
    }

    /* wait for threads termination */
    pthread_join(threads[0], NULL);
    pthread_join(threads[1], NULL);

    /* shutdown connection */
    shutdown(sd, SHUT_RDWR);
    close(sd);

    /* close log */
    if (f_logging)
	fclose(lfout);

    exit(0);
}
