/*
 * ESE, a HyperText Transfer Protocol server
 * Copyright (C) 1996-2001 Akira Higuchi <a-higuti@math.sci.hokudai.ac.jp>
 * All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <time.h>
#include <unistd.h>

#include <iostream>
#include <map>
#include <string>
#include <vector>


/* DEBUG_DO (stmt) expands to nothing, so every DEBUG_DO (...) trace in
   this file is compiled out.  Defining it as "#define DEBUG_DO(x) x"
   would turn the traces back into real printf () calls. */
#define DEBUG_DO(x)

using namespace std;

/*
 * printf-style formatting straight into a std::string.
 * The process is aborted if vasprintf () cannot build the text.
 */
static string format_string (const char *format, ...)
{
  va_list args;
  char *raw = NULL;

  va_start (args, format);
  const int rc = vasprintf (&raw, format, args);
  if (rc < 0)
    abort ();
  va_end (args);

  string result (raw);
  free (raw);
  return result;
}

/*
 * Error collector: every distinct message is stored once together with
 * the number of times it was reported, so repeated failures are printed
 * as a single line with a count.
 */
class ec_errortable {
private:
  /* message text -> number of occurrences */
  map <string, int> table;

public:
  /* record ARG verbatim */
  void append (string arg);
  /* record ARG followed by ": " and strerror (errno) */
  void perror (string arg);
  /* print every recorded message with its count to cout */
  void dump (void);
};

/* Record ARG as-is, incrementing its occurrence count. */
void
ec_errortable::append (string arg)
{
  int& occurrences = table[arg];
  occurrences += 1;
}

/*
 * Record "ARG: <description of errno>", incrementing its occurrence
 * count.  Must be called before anything else can overwrite errno.
 */
void
ec_errortable::perror (string arg)
{
  const char *reason = strerror (errno);
  const string key = arg + ": " + reason;
  ++table[key];
}

/*
 * Print the collected errors to cout, one "message (N times)" line per
 * distinct message, framed by an "Errors:" heading and a blank line.
 * Prints nothing at all when no error was recorded.
 */
void
ec_errortable::dump (void)
{
  if (table.empty ())
    return;

  cout << "Errors:\n";
  for (map <string, int>::const_iterator entry = table.begin ();
       entry != table.end (); ++entry)
    cout << format_string ("%s (%d times)\n",
                           entry->first.c_str (), entry->second);
  cout << "\n";
}

class ec_connection {
  int fd;
  ////////////////////////////////////////////////////////////////
  vector <char> readbuf;
  vector <char> writebuf;
  size_t read_len, write_len;
  size_t response_header_offset;
  size_t response_entity_length;
  size_t response_entity_offset;
  int expect_response_count;
  int got_response_count;
  ////////////////////////////////////////////////////////////////
  size_t total_read_len, total_write_len;
  int total_response_count;
  int bad_response_count;
  int non_200_count;
  int io_error_count;
  int connect_is_in_progress;
  int connection_refused_count;
  int method_is_head;
  map <string, int> non_200_map;
  string document_length;
  string server_software;
  size_t find_empty_line (const char *first, const char *last);
  vector <string> get_header_line (size_t& offset);
  ec_errortable *errortable;
  void parse_header (void);
  void parse_response (void);
public:
  ec_connection () :
    fd (-1), read_len (0), write_len (0),
    response_header_offset (0),
    response_entity_length (0), response_entity_offset (0),
    expect_response_count (0), got_response_count (0),
    total_read_len (0), total_write_len (0), total_response_count (0),
    bad_response_count (0), non_200_count (0), io_error_count (0),
    connect_is_in_progress (0), connection_refused_count (0),
    method_is_head (0) { }
  void start (ec_errortable *app_errortable,
	      const struct sockaddr_in& servaddr, int is_head);
  void append_send_queue (const vector <char>& buf);
  void pre_select (int& fdmax, fd_set& rfds, fd_set& wfds);
  void post_select (fd_set& rfds, fd_set& wfds);
  void close (void);
  bool is_closed (void) const { return fd < 0; }
  int get_non_200_count () const { return non_200_count; }
  const map <string, int>& get_non_200_map () const { return non_200_map; }
  int get_io_error_count () const { return io_error_count; }
  int get_bad_response_count () const { return bad_response_count; }
  int get_total_response_count () const { return total_response_count; }
  int get_connection_refused_count () const {
    return connection_refused_count;
  }
  size_t get_total_read_len () const { return total_read_len; }
  size_t get_total_write_len () const { return total_write_len; }
  string get_document_length () const { return document_length; }
  string get_server_software () const { return server_software; }
};

/*
 * Search [first, last) for the "\r\n\r\n" sequence that ends an HTTP
 * header block.
 *
 * Returns the number of bytes from FIRST up to and including that
 * sequence, or 0 when it does not occur (including when the range is
 * shorter than 4 bytes).
 */
size_t
ec_connection::find_empty_line (const char * first, const char *last)
{
  /* Without this guard "last - 3" would point before FIRST for short
     ranges and the scan below would never meet its stop pointer. */
  if (last - first < 4)
    return 0;
  const char *const stop = last - 3;
  for (const char *p = first; p != stop; p++) {
    if (p[0] == '\r' && p[1] == '\n' && p[2] == '\r' && p[3] == '\n')
      return (size_t) (p - first) + 4;
  }
  return 0;
}

/*
 * Split the next header line, starting at OFFSET, into words separated
 * by whitespace.  Scanning never goes past response_entity_offset.
 *
 * Line breaks met before the first word are skipped.  The first CR/LF
 * after a word ends the line; the whole CR/LF run is consumed and
 * OFFSET is left at the start of the next line.  Returns an empty
 * vector once the header block is exhausted.
 *
 * Bytes are examined as unsigned char: other control bytes are treated
 * as separators, and bytes >= 0x80 are word characters.  This
 * guarantees that OFFSET advances on every iteration.
 */
vector <string>
ec_connection::get_header_line (size_t& offset)
{
  vector <string> v;
  const size_t end = response_entity_offset;
  while (offset < end) {
    const unsigned char c = readbuf[offset];
    if (c == '\r' || c == '\n') {
      if (!v.empty ()) {
	/* end of a non-empty line: swallow the whole CR/LF run */
	while (offset < end &&
	       (readbuf[offset] == '\r' || readbuf[offset] == '\n'))
	  offset++;
	return v;
      }
      /* stray line break before the first word */
      offset++;
    } else if (isspace (c) || c < 32) {
      /* separator (blanks, and control bytes so the loop always
	 makes progress) */
      offset++;
    } else {
      const size_t word_beginning = offset;
      while (offset < end) {
	const unsigned char d = readbuf[offset];
	if (isspace (d) || d < 32)
	  break;
	offset++;
      }
      v.push_back (string (&readbuf[word_beginning],
			   offset - word_beginning));
    }
  }
  return v;
}

/*
 * Parse the response header stored in
 * readbuf[response_header_offset, response_entity_offset).
 *
 * This function does the following:
 *  - Updates bad_response_count, non_200_count and non_200_map.
 *  - Remembers the first Content-Length and Server values seen by this
 *    slot.
 *  - Sets response_entity_length to the number of body bytes to expect
 *    (forced to 0 for HEAD requests).
 *
 * Throws a string in three cases:
 *  - the response uses Transfer-Encoding;
 *  - the Content-Length is negative;
 *  - no Content-Length has been seen by this slot so far.
 */
void
ec_connection::parse_header (void)
{
  DEBUG_DO (printf ("parse_header\n"));
  vector <string> line;
  size_t offset = response_header_offset;
  line = get_header_line (offset);
  if (line.size () > 0) {
    DEBUG_DO (printf ("HTTP response: %s\n", line[0].c_str ()));
  }
  /* Status line: HTTP-version SP status-code SP reason-phrase.  The
     reason phrase may be empty (RFC 7230 section 3.1.2), so only the
     version and the status code are required here. */
  if (line.size () < 2) {
    bad_response_count++;
    return;
  }
  if (line[1] != "200") {
    non_200_map[line[1]]++;
    non_200_count++;
#if 0
    cout << "total_response_count: " << total_response_count << endl;
    cout << string(writebuf.begin(), writebuf.end()).c_str() << endl;
#endif
  }
  while (1) {
    line = get_header_line (offset);
    if (line.size () == 0)
      break;
    else if (line.size () < 2) {
      continue;
    }
    const char *header = line[0].c_str ();
    if (strcasecmp (header, "content-length:") == 0) {
      const long length = atol (line[1].c_str ());
      /* a negative value would wrap to a huge size_t and make us wait
	 forever for a body that never comes */
      if (length < 0)
	throw string ("invalid content-length field");
      response_entity_length = (size_t) length;
      if (document_length == "")
	document_length = line[1];
    } else if (server_software == "" && strcasecmp (header, "server:") == 0) {
      server_software = line[1];
    } else if (strcasecmp (header, "transfer-encoding:") == 0) {
      throw string ("esebench does not understand transfer-encoding field");
    }
  }
  if (document_length == "") {
    throw string ("document length is unknown");
  }
  DEBUG_DO (printf ("HTTP response: length = %d\n", response_entity_length));
  if (method_is_head)
    response_entity_length = 0;
}

/*
 * Consume every complete response currently buffered in
 * readbuf[response_header_offset, read_len).
 *
 * The function alternates between two states:
 *  - locating the header terminator, then calling parse_header ();
 *  - waiting until response_entity_length body bytes have arrived.
 *
 * Each finished response bumps got_response_count and
 * total_response_count.  It returns as soon as more data is needed.
 * Exceptions thrown by parse_header () propagate to the caller.
 */
void
ec_connection::parse_response (void)
{
  DEBUG_DO (printf ("parse_response: %s\n", &readbuf[response_header_offset]));
  if (read_len == 0)
    return;
  while (1) {
    if (response_entity_length > 0) {
      /* we're reading response entity */
      if (read_len >= response_entity_offset + response_entity_length) {
	/* we've got the whole body */
	response_header_offset = response_entity_offset +
	  response_entity_length;
	response_entity_offset = response_header_offset;
	response_entity_length = 0;
	got_response_count++;
	total_response_count++;
      } else {
	/* body still continues */
	return;
      }
    } else {
      /* we're reading response header */
      DEBUG_DO (printf ("reading header: %d %d\n",
			response_header_offset, read_len));
      if (read_len < response_header_offset + 4)
	return;
      /* Form the end pointer by arithmetic: read_len may equal
	 readbuf.size (), and readbuf[read_len] would then index past
	 the end of the vector. */
      const char *base = &readbuf[0];
      size_t p = find_empty_line (base + response_header_offset,
				  base + read_len);
      DEBUG_DO (printf ("empty line offset: %d\n", p));
      if (p > 0) {
	response_entity_offset = response_header_offset + p;
	parse_header ();
	if (response_entity_length == 0) {
	  /* body is empty */
	  response_header_offset = response_entity_offset;
	  got_response_count++;
	  total_response_count++;
	}
      } else {
	/* response header still continues */
	return;
      }
    }
  }
}

/*
 * Open a new non-blocking TCP connection to SERVADDR in this (closed)
 * slot.
 *
 * Per-connection buffers and offsets are reset; the statistics are
 * kept.  An immediate connect () failure counts as a refused
 * connection.  EINPROGRESS leaves connect_is_in_progress set, so that
 * post_select () checks the outcome later.
 *
 * On any failure the error is recorded in APP_ERRORTABLE (via perror,
 * i.e. with strerror (errno)) and the slot is closed again.
 */
void
ec_connection::start (ec_errortable *app_errortable,
		      const struct sockaddr_in& servaddr, int is_head)
{
  DEBUG_DO (printf ("start\n"));
  assert (fd < 0);
  errortable = app_errortable;
  connect_is_in_progress = 0;
  readbuf.erase (readbuf.begin (), readbuf.end ());
  writebuf.erase (writebuf.begin (), writebuf.end ());
  read_len = write_len = 0;
  response_header_offset = 0;
  response_entity_length = 0;
  response_entity_offset = 0;
  expect_response_count = 0;
  got_response_count = 0;
  method_is_head = is_head;
  int v;
  try {
    if ((fd = socket (AF_INET, SOCK_STREAM, 0)) < 0)
      throw ("socket");
    /* select () cannot watch descriptors >= FD_SETSIZE; FD_SET on such
       a descriptor would write past the end of the fd_set. */
    if (fd >= FD_SETSIZE) {
      errno = EMFILE;
      throw ("socket");
    }
    /* F_SETFL replaces all file status flags, so add O_NONBLOCK to the
       current ones instead of overwriting them. */
    int flags = fcntl (fd, F_GETFL);
    if (flags < 0)
      throw ("fcntl F_GETFL");
    if (fcntl (fd, F_SETFL, flags | O_NONBLOCK) < 0)
      throw ("fcntl F_SETFL O_NONBLOCK");
    v = 1;
    if (setsockopt (fd, IPPROTO_TCP, TCP_NODELAY, &v, sizeof (v)) < 0)
      throw ("setsockopt TCP_NODELAY");
    v = 1;
    if (setsockopt (fd, SOL_SOCKET, SO_KEEPALIVE, &v, sizeof (v)) < 0)
      throw ("setsockopt SO_KEEPALIVE");
#if 1
    v = 16384;
    if (setsockopt (fd, SOL_SOCKET, SO_SNDBUF, &v, sizeof (v)) < 0)
      throw ("setsockopt SO_SNDBUF");
#endif
    int r = connect (fd, (struct sockaddr *)&servaddr, sizeof (servaddr));
    if (r < 0) {
      if (errno == EINPROGRESS) {
	connect_is_in_progress = 1;
      } else {
	connection_refused_count++;
	throw ("connect");
      }
    } else {
      connect_is_in_progress = 0;
    }
  } catch (const char *s) {
    /* record before close (), which may overwrite errno */
    errortable->perror (s);
    close ();
  }
}

/*
 * Queue one complete request (BUF) behind any bytes not yet written,
 * and note that one more response is now expected on this connection.
 */
void
ec_connection::append_send_queue (const vector <char>& buf)
{
  DEBUG_DO (printf ("append %s", &buf[0]));
  writebuf.insert (writebuf.end (), buf.begin (), buf.end ());
  ++expect_response_count;
}

void
ec_connection::pre_select (int& fdmax, fd_set& rfds, fd_set& wfds)
{
  if (fd < 0)
    return;
  fdmax = (fd > fdmax) ? fd : fdmax;
  DEBUG_DO (printf ("fd %d: read\n", fd));
  if (!connect_is_in_progress)
    FD_SET (fd, &rfds);
  if (connect_is_in_progress || writebuf.size () > write_len) {
    FD_SET (fd, &wfds);
    DEBUG_DO (printf ("fd %d: write\n", fd));
  }
}

void
ec_connection::post_select (fd_set& rfds, fd_set& wfds)
{
  if (fd >= 0 && FD_ISSET (fd, &rfds)) {
    const size_t readmax = 8192;
    readbuf.resize (read_len + readmax);
    int r = read (fd, &readbuf[read_len], readmax);
    if (r == 0) {
      DEBUG_DO (printf ("read EOF\n"));
      if (got_response_count < expect_response_count)
        io_error_count++;
      close ();
    } else if (r < 0) {
      if (errno != EINTR && errno != EWOULDBLOCK) {
	errortable->perror ("read");
	io_error_count++;
	close ();
      }
    } else {
      DEBUG_DO (printf ("read %d bytes\n", r));
      read_len += r;
      total_read_len += r;
      try {
	parse_response ();
      }
      catch (string str) {
	errortable->append (str);
	close ();
      }
    }
  }
  if (fd >= 0 && FD_ISSET (fd, &wfds)) {
    if (connect_is_in_progress) {
      connect_is_in_progress = 0;
      int v;
      socklen_t optlen = sizeof (v);
      int r = getsockopt (fd, SOL_SOCKET, SO_ERROR, &v, &optlen);
      if (r < 0 || v < 0) {
	fd = -1;
	connection_refused_count++;
      }
    }
    if (fd >= 0) {
      size_t writemax = writebuf.size () - write_len;
      DEBUG_DO (printf ("writemax = %d\n", writemax));
      int r = write (fd, &writebuf[write_len], writemax);
      if (r == 0) {
	DEBUG_DO (printf ("write returns 0\n"));
	io_error_count++;
	close ();
      } else if (r < 0) {
	if (errno != EINTR && errno != EWOULDBLOCK) {
	  errortable->perror ("write");
	  io_error_count++;
	  close ();
	}
      } else {
	DEBUG_DO (printf ("write %d bytes\n", r));
	write_len += r;
	total_write_len += r;
	if (write_len == writebuf.size ()) {
	  //	  shutdown (fd, SHUT_WR);
	  DEBUG_DO (printf ("shutdown\n"));
	}
      }
    }
  }
}

/*
 * Close the socket (if any) and mark the slot as closed.  Safe to call
 * on an already closed slot.
 */
void
ec_connection::close (void)
{
  /* 0 is a valid descriptor (e.g. when stdin was closed), so test >= 0 */
  if (fd >= 0)
    ::close (fd);
  fd = -1;
}

/*
 * The benchmark driver.  It owns the parsed options, the pool of
 * connection slots and the shared error table, and runs the select ()
 * loop.
 */
class ec_app {
  /* ---- configuration, filled in by the constructor ---- */
  map <string, string> options;
  struct sockaddr_in servaddr;
  string hostname;
  string uri;
  string method;
  int nconns, concurrency, nrequests;
  int timeout_sec;
  int idtag;
  time_t start_at;

  /* ---- run-time state ---- */
  vector <ec_connection> conns;
  int nconns_done;
  ec_errortable errortable;
  /* NOTE(review): request/request_last and append_connection () are not
     used or defined in the code visible alongside this class */
  string request, request_last;

  /* ---- helpers ---- */
  void show_options (void);
  void fill_connection (void);
  int pre_select (int& fdmax, fd_set& rfds, fd_set& wfds);
  void post_select (fd_set& rfds, fd_set& wfds);
  void append_connection (string& request, int num_requests);
  void show_summary (struct timeval& timediff);

public:
  ec_app (const map <string, string>& options_arg);
  void run (void);
};

void
ec_app::fill_connection (void)
{
  vector <ec_connection>::iterator it;
  for (it = conns.begin ();
       it != conns.end () && (nconns == 0 || nconns_done < nconns);
       it++) {
    if (it->is_closed ()) {
      nconns_done++;
      it->start (&errortable, servaddr, options["method"] == "HEAD");
      for (int i = 0; i < nrequests; i++) {
	static int id = 0;
        vector <char> req;
	char buf[4097];
	if (idtag) {
	  snprintf(buf, 4096, 
		   "%s %s HTTP/1.1\r\n"
		   "Host: %s\r\n"
		   "%s"
		   "ID: %08x\r\n"
		   "\r\n",
		   method.c_str (), uri.c_str (),
		   hostname.c_str (),
		   (i == nrequests - 1) ? "Connection: close\r\n" : "",
		   id++);
	} else {
	  snprintf(buf, 4096, 
		   "%s %s HTTP/1.1\r\n"
		   "Host: %s\r\n"
		   "%s"
		   "\r\n",
		   method.c_str (), uri.c_str (),
		   hostname.c_str (),
		   (i == nrequests - 1) ? "Connection: close\r\n" : "");
	}
	req.insert(req.begin(), buf, buf + strlen(buf));
	it->append_send_queue (req);
      }
    }
  }
}

void
ec_app::show_options (void)
{
  static const char *const opts[] = {
    "server", "port", "method", "uri", "connections", "requests",
    "concurrency", "timeout", "at", "idtag",
  };
  cout << "\neseclient";
  for (size_t i = 0; i < sizeof (opts) / sizeof (opts[0]); i++) {
    cout << format_string (" %s=%s", opts[i], options[opts[i]].c_str ());
  }
  cout << "\n";
}

/*
 * Build the benchmark from the parsed command-line options.
 *
 * Missing or invalid options fall back to defaults, and the effective
 * values are written back into `options` for later reporting.
 * Defaults:
 *   server=localhost port=80 uri=/ method=GET connections=1
 *   concurrency=1 requests=1
 *
 * Special options:
 *  - A positive "timeout" switches to time-limited mode (connections=0,
 *    i.e. unlimited).
 *  - A positive "at" makes run () wait until time (NULL) reaches that
 *    value.
 *
 * Exits the process if the server name cannot be resolved to an IPv4
 * address.
 */
ec_app::ec_app (const map <string, string>& options_arg)
{
  options = options_arg;
  string portnum;
  hostname = options["server"];
  if (hostname == "") {
    hostname = "localhost";
    options["server"] = "localhost";
  }
  portnum = options["port"];
  if (portnum == "") {
    portnum = "80";
    options["port"] = "80";
  }
  struct hostent *hp;
  if ((hp = gethostbyname (hostname.c_str ())) == NULL) {
    herror ("gethostbyname");
    exit (1);
  }
  /* servaddr is a sockaddr_in, so only a 4-byte IPv4 address fits in
     sin_addr; refuse anything else instead of overflowing it */
  if (hp->h_addrtype != AF_INET ||
      hp->h_length != (int) sizeof (servaddr.sin_addr)) {
    fprintf (stderr, "gethostbyname: %s is not an IPv4 host\n",
	     hostname.c_str ());
    exit (1);
  }
  memset (&servaddr, 0, sizeof (struct sockaddr_in));
  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons (atoi (portnum.c_str ()));
  memcpy (&servaddr.sin_addr, hp->h_addr, sizeof (servaddr.sin_addr));
  uri = options["uri"];
  if (uri == "") {
    uri = "/";
    options["uri"] = "/";
  }
  nconns = atoi (options["connections"].c_str ());
  if (nconns < 1) {
    nconns = 1;
    options["connections"] = "1";
  }
  concurrency = atoi (options["concurrency"].c_str ());
  if (concurrency < 1) {
    concurrency = 1;
    options["concurrency"] = "1";
  }
  nrequests = atoi (options["requests"].c_str ());
  if (nrequests < 1) {
    nrequests = 1;
    options["requests"] = "1";
  }
  method = options["method"];
  if (method == "") {
    method = "GET";
    options["method"] = "GET";
  }
  string timeout = options["timeout"];
  timeout_sec = atoi (timeout.c_str ());
  if (timeout_sec <= 0) {
    timeout_sec = 0;
    options["timeout"] = "";
  } else {
    /* time-limited run: open connections without a count limit */
    nconns = 0;
    options["connections"] = "0";
  }
  start_at = atol (options["at"].c_str ());
  if (start_at <= 0) {
    start_at = 0;
    options["at"] = "";
  }
  idtag = (options["idtag"] != "") ? 1 : 0;
  options["idtag"] = idtag ? "yes" : "no";
  show_options ();
  conns.resize (concurrency);
  nconns_done = 0;
}

int
ec_app::pre_select (int& fdmax, fd_set& rfds, fd_set& wfds)
{
  fdmax = -1;
  FD_ZERO (&rfds);
  FD_ZERO (&wfds);
  vector <ec_connection>::iterator it;
  fill_connection ();
  int n;
  for (n = 0, it = conns.begin (); it != conns.end (); it++) {
    it->pre_select (fdmax, rfds, wfds);
    if (it->is_closed ())
      n++;
  }
  return (n == concurrency);
}

/* Hand the select () results to every connection slot, in order. */
void
ec_app::post_select (fd_set& rfds, fd_set& wfds)
{
  for (size_t i = 0; i < conns.size (); ++i)
    conns[i].post_select (rfds, wfds);
}

/*
 * Main benchmark loop.
 *
 * If start_at is set, spins (yielding with usleep (0)) until time (NULL)
 * reaches it, then prints "start".  It then repeatedly refills the
 * connection slots, waits in select () and services the ready sockets.
 *
 * The loop ends in three cases:
 *  - every slot is closed;
 *  - the timeout deadline has passed;
 *  - select () fails with an error other than EINTR, which is recorded
 *    in the error table.
 *
 * Finally prints the summary and the error table.
 */
void
ec_app::run (void)
{
  fd_set rfds, wfds;
  int fdmax;
  struct timeval time1, time2, timediff, timeout;
  if (start_at) {
    time_t now;
    while (1) {
      now = time (NULL);
      if (now >= start_at)
	break;
      usleep (0);
    }
    cout << "start\n";
  }
  cout.flush ();
  gettimeofday (&time1, NULL);
  timeout = time1;
  timeout.tv_sec += timeout_sec;
  while (1) {
    struct timeval tv;
    fd_set rfds_bak, wfds_bak;
    if (pre_select (fdmax, rfds, wfds))
      break;
    if (timeout_sec) {
      struct timeval now;
      gettimeofday (&now, NULL);
      if (!timercmp (&timeout, &now, >))
	break;
      timersub (&timeout, &now, &tv);
    } else {
      tv.tv_sec = 1;
      tv.tv_usec = 0;
    }
    rfds_bak = rfds;
    wfds_bak = wfds;
    int nready = select (fdmax + 1, &rfds, &wfds, NULL, &tv);
    if (nready <= 0) {
#if 0
      cout << "select ([";
      int i;
      for (i = 0; i < 1024; i++) {
	if (FD_ISSET (i, &rfds_bak))
	  cout << format_string (" %d", i);
      }
      cout << " ] [";
      for (i = 0; i < 1024; i++) {
	if (FD_ISSET (i, &wfds_bak))
	  cout << format_string (" %d", i);
      }
      cout << " ])\n";
#endif
      if (nready < 0 && errno != EINTR) {
	/* a persistent select () failure would otherwise spin forever */
	errortable.perror ("select");
	break;
      }
      /* Timeout or signal: after an error the sets are left unmodified,
	 so they carry no readiness information -- do not act on them. */
      continue;
    }
    post_select (rfds, wfds);
  }
  gettimeofday (&time2, NULL);
  timersub (&time2, &time1, &timediff);
  show_summary (timediff);
  cout << "\n";
  errortable.dump ();
}

/*
 * Print the aggregated results of the run.
 *
 * Counters of all connection slots are summed, and per-status-code
 * non-200 counts are merged across slots.  Throughput rates use the
 * wall-clock duration TIMEDIFF.  The error breakdown is printed only
 * when at least one error occurred.
 */
void
ec_app::show_summary (struct timeval& timediff)
{
  vector <ec_connection>::iterator it;
  int non_200_count = 0;
  int io_error_count = 0;
  int bad_response_count = 0;
  int total_response_count = 0;
  int connection_refused_count = 0;
  map <string, int> non_200_map;
  size_t total_read_len = 0;
  size_t total_write_len = 0;
  string document_length;
  string server_software;
  for (it = conns.begin (); it != conns.end (); it++) {
    non_200_count += it->get_non_200_count ();
    io_error_count += it->get_io_error_count ();
    bad_response_count += it->get_bad_response_count ();
    total_response_count += it->get_total_response_count ();
    connection_refused_count += it->get_connection_refused_count ();
    total_read_len += it->get_total_read_len ();
    total_write_len += it->get_total_write_len ();
    const map <string, int>& tmap (it->get_non_200_map ());
    map <string, int>::const_iterator mit;
    for (mit = tmap.begin (); mit != tmap.end (); mit++) {
      /* accumulate: the same status code usually appears on many
	 connections, so assigning would keep only the last slot's count */
      non_200_map[mit->first] += mit->second;
    }
    if (document_length == "")
      document_length = it->get_document_length ();
    if (server_software == "")
      server_software = it->get_server_software ();
  }
  double time_real;
  time_real  = timediff.tv_usec;
  time_real /= 1000000;
  time_real += timediff.tv_sec;
  cout << "\n";
  cout << format_string ("Server Software:           %s\n",
			 server_software.c_str ());
  cout << format_string ("Server Hostname:           %s\n",
			 options["server"].c_str ());
  cout << format_string ("Server Port:               %s\n",
			 options["port"].c_str ());
  cout << format_string ("Request Method:            %s\n",
			 options["method"].c_str ());
  cout << format_string ("Document URI:              %s\n",
			 options["uri"].c_str ());
  cout << format_string ("Document Length:           %s\n",
			 document_length.c_str ());
  cout << format_string ("Time:                      %.3f sec\n", time_real);
  cout << format_string ("\n");
  /* %lu needs an unsigned long; size_t is not one on every platform */
  cout << format_string ("Read Bytes:         %15lu   (%10.3f kb/sec)\n",
			 (unsigned long) total_read_len,
			 ((double)total_read_len) / time_real / 1000);
  cout << format_string ("Wrote Bytes:        %15lu   (%10.3f kb/sec)\n",
			 (unsigned long) total_write_len,
			 ((double)total_write_len) / time_real / 1000);
  cout << format_string ("Connections:        %15d   "
			 "(%10.3f connections/sec)\n",
			 nconns_done, ((double)nconns_done) / time_real);
  cout << format_string ("Total Responses:    %15d   (%10.3f requests/sec)\n",
	     total_response_count, ((double)total_response_count) / time_real);
  int error_count = non_200_count + io_error_count + bad_response_count
    + connection_refused_count;
  if (error_count > 0) {
    cout << format_string ("Errors:             %15d\n", error_count);
    cout << format_string ("Non 200 Responses:  %15d\n", non_200_count);
    cout << format_string ("IO Error:           %15d\n", io_error_count);
    cout << format_string ("Bad Responses:      %15d\n", bad_response_count);
    cout << format_string ("Connection Refused: %15d\n",
			   connection_refused_count);
    if (non_200_count) {
      map <string, int>::iterator mit;
      for (mit = non_200_map.begin (); mit != non_200_map.end (); mit++) {
	cout << format_string ("Response %s:       %15d\n",
		   mit->first.c_str (), mit->second);
      }
    }
  }
}

/*
 * Entry point.
 *
 * Every argument has the form key=value; a bare "key" is stored with an
 * empty value.  SIGPIPE is ignored so that writes to a peer-closed
 * socket fail with an error instead of killing the process.  The
 * benchmark is then built and run.
 */
int main (int argc, char **argv)
{
  map <string, string> options;
  for (int i = 1; i < argc; i++) {
    /* parse a copy: assigning "" to a char * is ill-formed since
       C++11, and argv does not need to be modified */
    const string arg (argv[i]);
    const string::size_type eq = arg.find ('=');
    if (eq == string::npos)
      options[arg] = "";
    else
      options[arg.substr (0, eq)] = arg.substr (eq + 1);
  }
  signal (SIGPIPE, SIG_IGN);
  ec_app app (options);
  app.run ();
  return 0;
}
