import MySQLdb, sys, configparser, time, subprocess, codecs, os

from PyQt5.QtWidgets import (QVBoxLayout, QLineEdit,QGridLayout,
 QApplication, QLabel, QMenu, QWidget, QHBoxLayout,QDialog,
 QAction, QTextEdit,QPushButton, qApp, QMessageBox, QInputDialog,
 QMainWindow, QFrame, QFileDialog, QCheckBox, QProgressBar,
 QListWidget, QCompleter, QComboBox, QTreeView,QAbstractItemView,)
from PyQt5.QtGui import (QIcon, QFont, QColor, QPixmap, QDrag, QPainter,
 QBrush, QPen,QStandardItem,QStandardItemModel)
from PyQt5.QtCore import (QCoreApplication, Qt, QObject, pyqtSignal,
 QMimeData, QAbstractItemModel,QModelIndex)

from main_config import main_Config


class Mysql_Fasttext():
    def __init__(self, project, mincount, dim, epoch, lr, word_ngram, minn, maxn):
        """Run the whole embedding pipeline for ``project`` on construction.

        Every argument is forwarded unchanged, in order, to
        ``with_fasttext``; ``mincount`` becomes that method's ``min_count``.
        """
        hyper_params = (mincount, dim, epoch, lr, word_ngram, minn, maxn)
        self.with_fasttext(project, *hyper_params)


    def with_fasttext(self, project, min_count, dim, epoch, lr, word_ngram, minn, maxn):
        """Export the project's text, train fastText on it and store the vectors.

        Global settings come from ``main_Config()``; the project's sampling
        mode and label count are read from ``<project>.ini`` (resolved
        relative to the current working directory).

        Args:
            project: Project name; used as the ini file stem and passed on
                as the database / model name.
            min_count, dim, epoch, lr, word_ngram, minn, maxn: fastText
                hyper-parameters, forwarded to ``to_fasttext``.

        Raises:
            KeyError: If the ini file is missing or lacks the ``[project]``
                section or its ``sample`` / ``nums_of_label`` keys.
        """
        # The third value returned by main_Config() is not needed here.
        mysql_id, mysql_pass, _, root_path, using_os = main_Config()

        cfg = configparser.ConfigParser()
        cfg.read(project + ".ini")
        sampling = cfg['project']["sample"]
        labels = cfg['project']["nums_of_label"]

        encode_temp_path = os.getcwd() + "/encode_temp.txt"
        try:
            # Fetch the morphologically analysed text from MySQL and write it
            # to encode_temp.txt.
            self.import_from_mysql(project, sampling, labels, mysql_id, mysql_pass, using_os)

            # Feed encode_temp.txt to fastText and store the resulting
            # distributed representations.
            self.to_fasttext(project, labels, min_count, dim, epoch, lr, word_ngram, minn, maxn,
                             root_path, sampling, mysql_id, mysql_pass, using_os)
        finally:
            # Always clean up the temporary export, even when a step above
            # failed; a missing file must not mask the original error.
            try:
                os.remove(encode_temp_path)
            except FileNotFoundError:
                pass

    def import_from_mysql(self, project, sampling, labels, mysql_id, mysql_pass, using_os):
        """Dump the project's morpheme text from MySQL into ``encode_temp.txt``.

        The file is written to the current working directory, one sentence
        per line. With labels the line is ``__label__<label_name>, <morpheme>``;
        without labels it is just the morpheme text.

        Args:
            project: MySQL database name.
            sampling: Sampling mode (int or numeric string). ``1`` selects rows
                with ``sampling = 1``; ``10``/``11`` select rows with
                ``train_test_data = 1``; any other value selects every row.
            labels: Number of labels (int or numeric string); ``> 0`` means
                the label column is exported too.
            mysql_id: MySQL user name.
            mysql_pass: MySQL password.
            using_os: ``"nt"`` writes CRLF line endings, ``"posix"`` writes LF.
                Any other value writes no lines at all.

        Side effects:
            In the sampled modes (1, 10, 11) ``self.sentenceId`` is set to the
            list of exported ``sentence_id`` values, in file order.

        Any exception raised while talking to the database or writing the file
        is printed and swallowed (best-effort), leaving the file empty or
        partially written.
        """
        # Truncate the file first; if the export below fails before the
        # rewrite, the file stays empty.
        with codecs.open("encode_temp.txt", "w", "utf-8"):
            pass
        time.sleep(0.1)

        sampling = int(sampling)
        labels = int(labels)
        sampled = sampling in (1, 10, 11)
        labelled = labels > 0

        if sampling == 1:
            if labelled:
                sql = "select sentence_id, label_name, morpheme from text_table where sampling = 1"
            else:
                sql = "select sentence_id, morpheme from text_table where sampling = 1"
        elif sampling in (10, 11):
            sql = "select sentence_id, label_name, morpheme from text_table where train_test_data = 1"
        elif labelled:
            sql = "select label_name, morpheme from text_table"
        else:
            sql = "select morpheme from text_table"

        try:
            conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
            try:
                cursor = conn.cursor()

                # NOTE(review): the short sleeps are kept from the original
                # code; their purpose is not evident from this file.
                time.sleep(0.1)

                cursor.execute(sql)
                rows = cursor.fetchall()

                time.sleep(0.1)

                conn.commit()
                cursor.close()
            finally:
                # Release the connection even when the query fails.
                conn.close()

            # Sampled modes: remember the sentence ids and strip them from
            # the rows so the remaining handling matches the unsampled case.
            if sampled:
                self.sentenceId = [row[0] for row in rows]
                if labelled:
                    rows = [(row[1], row[2]) for row in rows]
                else:
                    # The morpheme is always the last column; in modes 10/11
                    # the query also returns label_name, which must not be
                    # written as the sentence text.
                    rows = [row[-1] for row in rows]

            if labelled:
                rows = [("__label__" + row[0], row[1]) for row in rows]

            if using_os == "nt":
                line_end = "\r\n"
            elif using_os == "posix":
                line_end = "\n"
            else:
                line_end = None

            with codecs.open("encode_temp.txt", "w", "utf-8") as out:
                for row in rows:
                    # Turn the tuple/str repr into plain text by stripping the
                    # quoting and parentheses that str() adds.
                    # NOTE(review): this also alters sentence text that itself
                    # contains "(" or doubled spaces - confirm that is acceptable.
                    line = str(row)
                    line = line.replace("(' ","").replace("',)","")
                    line = line.replace("', '", ", ").replace("('", "").replace("')","").replace("  "," ").replace("(","").replace("', '", ", ").replace(", '", ", ")
                    if line_end is not None:
                        out.write(line)
                        out.write(line_end)

        except Exception as e:
            print(e)

    def to_fasttext(self, project, labels, min_count, dim, epoch, lr, word_ngram, minn, maxn, root, sampling, mysql_id, mysql_pass, using_os):

        #OSで分岐
        #windowsのsupervisedはvecファイルを作らないので、先にskipgramでvecファイルを作成し、その後にsupevisedでbinファイルを作成→skipgramではbin内に文章情報がないと思われget-sentence-vectorが働かない
#" -epoch " + str(epoch) + " -lr " + str(le) + " -word_ngram " + word_ngram + " -minn " + minn + " -maxn " + maxn +
        if using_os == "nt":
            if int(labels) > 0: #ラベルありの場合の分散表現の学習
                cmd = root + "/fastText/fasttext.exe supervised -minCount " + str(min_count) + " -dim " + str(dim) + " -epoch " + str(epoch) + " -lr " + str(lr) + " -wordNgrams " + word_ngram + " -minn " + minn + " -maxn " + maxn + " -input " + root + "/encode_temp.txt -output " + root + "/fastText/" + project
                #cmd = root + "/fasttext/fasttext.exe supervised -minCount " + str(min_count) + " -dim " + str(dim) + " -input " + root + "/encode_temp.txt -output " + root + "/fasttext/" + project
                print(cmd)
            elif int(labels) == 0:
                #cmd = root + "/fasttext/fasttext.exe skipgram -minCount " + str(min_count) + " -dim " + str(dim) + " -input " + root + "/encode_temp.txt -output " + root + "/fasttext/" + project
                cmd = root + "/fastText/fasttext.exe skipgram -minCount " + str(min_count) + " -dim " + str(dim) + " -epoch " + str(epoch) + " -lr " + str(lr) + " -wordNgrams " + word_ngram + " -minn " + minn + " -maxn " + maxn + " -input " + root + "/encode_temp.txt -output " + root + "/fastText/" + project
            else:
                message=QMessageBox.information(self, "よくわからないエラーが発生", QMessageBox.Ok)
                #print("よくわからないエラーが発生")
            #cmd = root + "/fasttext/fasttext.exe skipgram -minCount " + str(min_count) + " -dim " + str(dim) + " -input " + root + "/encode_temp.txt -output " + root + "/fasttext/" + project
            #popen = subprocess.Popen(cmd, shell=True)
            #popen.wait()    #fasttextの変換処理が終了するまで待機

            #time.sleep(0.1)

            #cmd = root + "/fasttext/fasttext.exe supervised -minCount " + str(min_count) + " -dim " + str(dim) + " -input " + root + "/encode_temp.txt -output " + root + "/fasttext/" + project
            #popen = subprocess.Popen(cmd, shell=True)
            #popen.wait()    #fasttextの変換処理が終了するまで待機
            #time.sleep(0.1)

        elif using_os == "posix":
            if int(labels) > 0: #ラベルありの場合の分散表現の学習
                cmd = root + "/fastText/fasttext supervised -minCount " + str(min_count) + " -dim " + str(dim) + " -epoch " + str(epoch) + " -lr " + str(lr) + " -wordNgrams " + word_ngram + " -minn " + minn + " -maxn " + maxn + " -input " + root + "/encode_temp.txt -output " + root + "/fastText/" + project
                #print(cmd)
            elif int(labels) == 0:  #ラベル無しの場合の分散表現の学習
                cmd = root + "/fastText/fasttext skipgram -minCount " + str(min_count) + " -dim " + str(dim) + " -epoch " + str(epoch) + " -lr " + str(lr) + " -wordNgrams " + word_ngram + " -minn " + minn + " -maxn " + maxn + " -input " + root + "/encode_temp.txt -output " + root + "/fastText/" + project
            #minCountは最小出現回数何回の単語を分散表現の計算に使用するか
            #dimは分散表現に使用する次元数
            else:
                message=QMessageBox.information(self, "よくわからないエラーが発生", QMessageBox.Ok)
                #print("よくわからないエラーが発生")

        else:
            #print("よくわからないエラーが発生")
            message=QMessageBox.information(self, "よくわからないエラーが発生", QMessageBox.Ok)

        #print(cmd)

        popen = subprocess.Popen(cmd, shell=True)
        popen.wait()    #fasttextの変換処理が終了するまで待機

        sentence_vectors = self.get_sentence_vector(project, root)

        if int(sampling) == 0:
            self.sentenc_vectors_sampling_0_mysql(project, sentence_vectors, mysql_id, mysql_pass)
        elif int(sampling) == 1 or int(sampling) == 10 or int(sampling) == 11:
            self.sentenc_vectors_sampling_1_mysql(project, sentence_vectors, mysql_id, mysql_pass)
        else:
            message=QMessageBox.information(self, "よくわからないエラーが発生", QMessageBox.Ok)
            #print("よくわからないエラーが発生")

    def sentenc_vectors_sampling_1_mysql(self, project, sentence_vectors, mysql_id, mysql_pass):
        """Pair each vector with its recorded sentence id and save the pairs.

        ``self.sentenceId`` and ``sentence_vectors`` are matched position by
        position; ``zip`` silently drops surplus entries on either side. The
        resulting ``(vector, sentence_id)`` tuples go to ``input_mysql``.
        """
        rows = [
            (vector, sentence_id)
            for sentence_id, vector in zip(self.sentenceId, sentence_vectors)
        ]
        self.input_mysql(project, rows, mysql_id, mysql_pass)


    def sentenc_vectors_sampling_0_mysql(self, project, sentence_vectors, mysql_id, mysql_pass):
        """Number the vectors 1, 2, 3, ... and save them under those sentence ids.

        The 1-based position of each vector in ``sentence_vectors`` is used as
        its ``sentence_id``; the ``(vector, sentence_id)`` tuples are then
        handed to ``input_mysql`` for storage.
        """
        rows = [
            (vector, sentence_id)
            for sentence_id, vector in enumerate(sentence_vectors, start=1)
        ]
        self.input_mysql(project, rows, mysql_id, mysql_pass)


    def input_mysql(self, project, result, mysql_id, mysql_pass):
        """Write sentence vectors into ``text_table.sentence_vector``.

        Args:
            project: MySQL database name.
            result: Sequence of ``(sentence_vector, sentence_id)`` tuples, in
                the parameter order of the UPDATE statement below.
            mysql_id: MySQL user name.
            mysql_pass: MySQL password.

        Any exception is printed and swallowed (best-effort); the connection
        is closed in every case once it has been opened.
        """
        sql = "update text_table set sentence_vector = %s where sentence_id = %s"

        try:
            conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
            try:
                cursor = conn.cursor()

                # NOTE(review): the short sleeps are kept from the original
                # code; their purpose is not evident from this file.
                time.sleep(0.1)

                cursor.executemany(sql, result)

                time.sleep(0.1)

                conn.commit()
                cursor.close()
            finally:
                # Do not leak the connection when the update or commit fails.
                conn.close()

        except Exception as e:
            print(e)

    def get_sentence_vector(self, project, root):
        """Return the fastText sentence vectors for ``encode_temp.txt``.

        Runs ``<root>/fastText/fasttext print-sentence-vectors`` on the model
        ``<root>/fastText/<project>.bin`` with ``<root>/encode_temp.txt``
        redirected to its standard input.

        Args:
            project: Model name (without the ``.bin`` suffix).
            root: Directory containing ``fastText/`` and ``encode_temp.txt``.

        Returns:
            list[str]: One entry per non-empty output line, with the line
            terminator (LF or CRLF) removed. The command's exit status and
            stderr are not inspected, so a failed run yields an empty list.
        """
        cmd = root + "/fastText/fasttext print-sentence-vectors " + root + "/fastText/" + project + ".bin < " + root + "/encode_temp.txt"
        proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # communicate() drains both pipes and waits for the process to exit,
        # so no separate wait() is needed.
        stdout_data, _ = proc.communicate()

        # splitlines() handles both "\n" and "\r\n", so no stray "\r" ends up
        # in the stored vectors when the tool emits Windows line endings.
        lines = stdout_data.decode("utf-8").splitlines()
        return [line for line in lines if line != ""]


if __name__=="__main__":
    # Manual smoke test: run the full pipeline for one project, then fetch
    # the sentence vectors once more from the trained model.
    project = "tttt"
    #project="only"
    mincount=1
    dim=100
    # The constructor requires all eight arguments; these are fastText's
    # unsupervised defaults. word_ngram / minn / maxn are given as strings
    # because to_fasttext concatenates them into the command line.
    epoch=5
    lr=0.05
    word_ngram="1"
    minn="3"
    maxn="6"
    root="/home/jabba/win20180129"
    Mysql_Fasttext(project, mincount, dim, epoch, lr, word_ngram, minn, maxn).get_sentence_vector(project, root)
