# -*- coding: utf-8 -*-
#
#  sstplib.py - an SSTP library module in Python
#  Copyright (C) 2001, 2002 by Tamito KAJIYAMA
#  Copyright (C) 2002, 2003 by MATSUMURA Namihiko <nie@counterghost.net>
#  Copyright (C) 2002-2013 by Shyouzou Sugitani <shy@users.sourceforge.jp>
#
#  This program is free software; you can redistribute it and/or modify it
#  under the terms of the GNU General Public License (version 2) as
#  published by the Free Software Foundation.  It is distributed in the
#  hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
#  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the GNU General Public License for more details.
#

import email
import re
import select
import socket
import socketserver
import logging
import time


class SSTPServer(socketserver.TCPServer):
    """TCP server for SSTP requests.

    ``allow_reuse_address`` asks TCPServer to set SO_REUSEADDR on the
    listening socket before binding it.
    """

    allow_reuse_address: bool = True


class AsynchronousSSTPServer(SSTPServer):
    """SSTP server whose handle_request() never waits for a client.

    The listening socket is polled with a zero timeout; a request is only
    handled when a connection is already pending.  Presumably meant to be
    called repeatedly from an external main loop.
    """

    def handle_request(self):
        # select() with a 0 timeout returns immediately; index 0 is the
        # list of sockets that are ready for reading.
        ready = select.select([self.socket], [], [], 0)[0]
        if ready:
            SSTPServer.handle_request(self)


class BaseSSTPRequestHandler(socketserver.StreamRequestHandler):
    """Request handler base class for SSTP.

    handle() reads one request line of the form ``COMMAND SSTP/x.y`` plus
    the header block that follows it, then calls ``do_COMMAND_x_y()``.
    Subclasses supply those ``do_*`` methods; a request with no matching
    method is answered with 501 Not Implemented.

    Attributes set while handling a request:
        requestline -- decoded request line without its line ending, or
                       '-' when it failed to parse.
        command, version -- the two groups matched by re_requestsyntax.
        headers -- email.message.Message holding the request headers.
        error -- status code last passed to send_error(), else None.
    """

    # Status codes this handler can send, with their reason phrases.
    responses = {
        200: 'OK',
        204: 'No Content',
        210: 'Break',
        400: 'Bad Request',
        408: 'Request Timeout',
        409: 'Conflict',
        420: 'Refuse',
        501: 'Not Implemented',
        503: 'Service Unavailable',
        510: 'Not Local IP',
        511: 'In Black List',
        512: 'Invisible',
        }
    # "COMMAND SSTP/x.y": an upper-case ASCII command and a one-digit
    # major/minor version.
    re_requestsyntax = re.compile(r'^([A-Z]+) SSTP/([0-9]\.[0-9])$')

    # Lines that end the header block: an empty line (CRLF or bare LF) or
    # EOF (readline() returning b'').
    _header_terminators = (b'\r\n', b'\n', b'')

    def parse_headers(self, fp):
        """Read and decode the header block from the binary stream *fp*.

        Lines are consumed up to and including the first empty line (or
        EOF).  The block is decoded with the charset named by its own
        ``Charset`` header, defaulting to Shift_JIS.  An unusable charset
        name falls back to Shift_JIS.  Bytes invalid in the chosen charset
        become U+FFFD, so a malformed request does not abort the handler
        with an exception.

        Returns an email.message.Message.
        """
        lines = []
        while True:
            line = fp.readline()
            lines.append(line)
            if line in self._header_terminators:
                break
        hbytes = b''.join(lines)
        # First pass over the raw bytes, only to find the declared charset.
        charset = email.message_from_bytes(hbytes).get('Charset', 'Shift_JIS')
        try:
            text = hbytes.decode(charset, 'replace')
        except (LookupError, TypeError, ValueError):
            # LookupError: unknown or non-text codec name.  TypeError /
            # ValueError: a Charset value containing non-ASCII bytes may come
            # back from the email package as a Header object or as a string
            # with surrogates, neither of which names a codec.
            text = hbytes.decode('Shift_JIS', 'replace')
        return email.message_from_string(text)

    def parse_request(self, requestline):
        """Parse the raw request line and, if valid, the headers after it.

        *requestline* is the bytes line read from the client.  On success
        this sets self.requestline, self.command, self.version and
        self.headers, and returns 1.  On a syntax error it replies
        400 Bad Request, sets self.requestline to '-' and returns 0.
        """
        # Bytes that are not valid Shift_JIS can never match the ASCII-only
        # request syntax.  Replacing them turns such a line into a 400 reply
        # instead of an uncaught UnicodeDecodeError.
        requestline = str(requestline, 'Shift_JIS', 'replace')
        if requestline.endswith('\r\n'):
            requestline = requestline[:-2]
        elif requestline.endswith('\n'):
            requestline = requestline[:-1]
        self.requestline = requestline
        match = self.re_requestsyntax.match(requestline)
        if not match:
            self.requestline = '-'
            self.send_error(400, 'Bad Request {0}'.format(repr(requestline)))
            return 0
        self.command, self.version = match.groups()
        self.headers = self.parse_headers(self.rfile)
        return 1

    def handle(self):
        """Handle one request: parse it, then call do_<COMMAND>_<x>_<y>()."""
        self.error = self.version = None
        if not self.parse_request(self.rfile.readline()):
            return
        # version is "x.y" (guaranteed by re_requestsyntax); for example,
        # "SEND SSTP/1.4" dispatches to do_SEND_1_4().
        name = 'do_{0}_{1}_{2}'.format(
            self.command, self.version[0], self.version[2])
        method = getattr(self, name, None)
        if method is None:
            self.send_error(
                501,
                'Not Implemented ({0}/{1})'.format(self.command, self.version))
            return
        method()

    def send_error(self, code, message=None):
        """Record *code* in self.error, log it, and send its status line.

        *message* is used only for the log entry.  The reply always carries
        the standard reason phrase from ``responses``.
        """
        self.error = code
        self.log_error(message or self.responses[code])
        self.send_response(code, self.responses[code])

    def send_response(self, code, message=None):
        """Log the request and write ``SSTP/<ver> <code> <reason>`` plus an
        empty line, encoded as Shift_JIS.

        The version echoes the request's version, or '1.0' if none was
        parsed.  *message* only affects the log entry.
        """
        self.log_request(code, message)
        self.wfile.write('SSTP/{0} {1:d} {2}\r\n\r\n'.format(
                self.version or '1.0', code, self.responses[code]).encode('Shift_JIS'))

    def log_error(self, message):
        """Log *message* at ERROR level on the root logger, with a timestamp."""
        # logging supplies the record terminator itself; a trailing '\n' here
        # would leave a blank line after every record.
        logging.error('[{0}] {1}'.format(self.timestamp(), message))

    def log_request(self, code, message=None):
        """Log client, request line and reply *code* at INFO level."""
        # client_hostname() does a blocking reverse DNS lookup, so skip all
        # of this work when INFO records would be discarded anyway.
        if not logging.getLogger().isEnabledFor(logging.INFO):
            return
        if self.requestline == '-':
            request = self.requestline
        else:
            request = '"{0}"'.format(self.requestline)
        logging.info('{0} [{1}] {2} {3:d} {4}'.format(
                self.client_hostname(), self.timestamp(),
                request, code, message or self.responses[code]))

    def client_hostname(self):
        """Return the client's host name, or its address if lookup fails.

        Returns 'localhost' when client_address is missing or is not a
        (host, port, ...) sequence.
        """
        address = getattr(self, 'client_address', None)
        # IPv4 addresses are (host, port); IPv6 ones are
        # (host, port, flowinfo, scope_id).  Both carry the host first.
        if not isinstance(address, (tuple, list)) or len(address) < 2:
            return 'localhost'
        host = address[0]
        try:
            hostname, _aliases, _addresses = socket.gethostbyaddr(host)
        except socket.error:
            hostname = host
        return hostname

    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    def timestamp(self):
        """Return local time as ``dd/Mon/yyyy:HH:MM:SS +hhmm``."""
        t = time.localtime(time.time())
        m = self.month_names[t.tm_mon - 1]
        # tm_gmtoff is the real offset for this instant, DST included.
        offset = getattr(t, 'tm_gmtoff', None)
        if offset is None:
            offset = -(time.altzone if t.tm_isdst > 0 else time.timezone)
        return '{0:02d}/{1}/{2:d}:{3:02d}:{4:02d}:{5:02d} {6}'.format(
            t.tm_mday, m, t.tm_year, t.tm_hour, t.tm_min, t.tm_sec,
            self._format_utc_offset(offset))

    @staticmethod
    def _format_utc_offset(seconds):
        """Format an offset given in seconds east of UTC as [+-]HHMM."""
        sign = '-' if seconds < 0 else '+'
        hours, minutes = divmod(abs(seconds) // 60, 60)
        return '{0}{1:02d}{2:02d}'.format(sign, hours, minutes)


def test(ServerClass=SSTPServer,
         HandlerClass=BaseSSTPRequestHandler,
         port=9801):
    """Manual smoke test: serve SSTP on *port* via serve_forever().

    Before serving, prints the port and the SO_REUSEADDR value actually
    present on the listening socket.
    """
    server = ServerClass(('', port), HandlerClass)
    print('Serving SSTP on port {0:d} ...'.format(port))
    reuse_flag = server.socket.getsockopt(
        socket.SOL_SOCKET, socket.SO_REUSEADDR)
    print('Allow reuse address: {0:d}'.format(reuse_flag))
    server.serve_forever()


if __name__ == '__main__':
    test()
