import time, MySQLdb, random, configparser, math

from write_project_ini import write_pro_ini
from main_config import main_Config


def learn_sampling(project, rate, sampling):
    """Redo the training/evaluation split for every label of ``project``.

    Reads the MySQL user name and password from ``main_Config()``, clears the
    sampling columns with ``train_test_data_reset`` and then splits each label
    with ``each_label_sampling``.

    Args:
        project: MySQL database name.
        rate: Percentage of each label's rows to mark as training data.
        sampling: 0 to sample from every row, 1 to sample only from rows whose
            ``sampling`` column is 1.
    """
    user, password, _projects, _root_path, _using_os = main_Config()

    # Reset the sampling columns before assigning new training/evaluation data.
    train_test_data_reset(project, user, password)
    each_label_sampling(project, user, password, rate, sampling)

def each_label_sampling(project, mysql_id, mysql_pass, rate, sampling):
    """Split the rows of every label into training and evaluation data.

    Counts the rows of ``text_table`` per ``label_name`` and, for each label,
    calls ``import_sampling`` with ``floor(count * rate / 100)`` training rows;
    the remaining rows of that label become evaluation data.

    Args:
        project: MySQL database name.
        mysql_id: MySQL user name.
        mysql_pass: MySQL password.
        rate: Training-data percentage (anything ``int()`` accepts, e.g. ``90``
            or ``"90"``).
        sampling: 0 (or ``"0"``) to use every row, 1 (or ``"1"``) to use only
            rows whose ``sampling`` column is 1.

    Returns:
        A list of ``(label, count, quotient, remainder)`` tuples, one per label,
        where ``quotient`` is the number of training rows and ``remainder`` the
        number of evaluation rows. Returns ``None`` when ``sampling`` is neither
        0 nor 1 (an error message is printed) or when a database/processing
        error occurs (the exception is printed).

    Raises:
        ValueError: if ``sampling`` cannot be converted with ``int()``.
    """
    mode = int(sampling)
    if mode not in (0, 1):
        print("エラー発生")
        return None

    if mode == 1:
        sql = "select label_name, count(label_name) from text_table where sampling = 1 group by label_name"
    else:
        sql = "select label_name, count(label_name) from text_table group by label_name"

    try:
        conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            nums = list(cursor.fetchall())

            time.sleep(0.1)
            conn.commit()
            cursor.close()
        finally:
            # Release the connection even if the query fails.
            conn.close()

        time.sleep(0.1)

        rate_percent = int(rate)

        result = []
        for label, count in nums:
            # Exact integer arithmetic: computing count * (rate / 100) in floating
            # point can land just below an integer (100 * 0.29 == 28.999999999999996),
            # which made floor() drop one training row.
            quotient = count * rate_percent // 100
            remainder = count - quotient
            result.append((label, count, quotient, remainder))

            import_sampling(project, mysql_id, mysql_pass, sampling, label, quotient, remainder)

        return result

    except Exception as e:
        print(e)

def import_sampling(project, mysql_id, mysql_pass, sampling, label, quotient, remainder):
    """Mark one label's rows as training (1) or evaluation (2) data.

    Fetches the ``sentence_id`` values of ``label`` from ``text_table`` (all
    rows when ``sampling`` is 0, otherwise only rows whose ``sampling`` column
    is 1). It then picks ``quotient`` of them at random with
    ``teacher_sampling`` and writes ``train_test_data = 1`` for those and
    ``train_test_data = 2`` for the rest via ``insert_sampling``.

    Args:
        project: MySQL database name.
        mysql_id: MySQL user name.
        mysql_pass: MySQL password.
        sampling: 0 to use every row of the label; any other int-convertible
            value restricts to rows with ``sampling = 1``.
        label: Value of ``label_name`` to process.
        quotient: Number of rows to mark as training data.
        remainder: Accepted for interface compatibility but not used; the
            evaluation set is simply every row that was not sampled.

    Errors raised while querying, sampling or updating are printed, not raised.

    Raises:
        ValueError: if ``sampling`` cannot be converted with ``int()``.
    """
    from collections import Counter

    if int(sampling) == 0:
        sql = "select sentence_id from text_table where label_name = %s"
    else:
        sql = "select sentence_id from text_table where sampling = 1 and label_name = %s"

    try:
        conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
        try:
            cursor = conn.cursor()
            # Bind the label as a parameter: concatenating it into the SQL text
            # broke on labels containing quotes and allowed SQL injection.
            cursor.execute(sql, (label,))
            nums = list(cursor.fetchall())

            time.sleep(0.1)
            conn.commit()
            cursor.close()
        finally:
            # Release the connection even if the query fails.
            conn.close()

        time.sleep(0.1)

        ids = [row[0] for row in nums]

        # Randomly chosen training data; 1 is the training-data code in MySQL.
        teacher = teacher_sampling(ids, quotient)
        teachers = [(1, sid) for sid in teacher]

        time.sleep(0.1)

        # Everything not sampled becomes evaluation data (code 2). Skipping one
        # occurrence per sampled id keeps the list.remove() semantics of the
        # previous version, without its O(n^2) cost.
        pending = Counter(teacher)
        evalus = []
        for sid in ids:
            if pending[sid]:
                pending[sid] -= 1
            else:
                evalus.append((2, sid))

        insert_sampling(project, mysql_id, mysql_pass, teachers)

        insert_sampling(project, mysql_id, mysql_pass, evalus)

    except Exception as e:
        print(e)

def insert_sampling(project, mysql_id, mysql_pass, sentence_id):
    """Write ``train_test_data`` values for a batch of sentences.

    Args:
        project: MySQL database name.
        mysql_id: MySQL user name.
        mysql_pass: MySQL password.
        sentence_id: Sequence of ``(train_test_data, sentence_id)`` pairs, e.g.
            ``[(1, 10), (1, 11)]``. Each pair fills the two ``%s`` placeholders
            of the UPDATE statement.

    Errors are printed, not raised.
    """
    if not sentence_id:
        # Nothing to update (e.g. every row of a label went to the other side).
        return

    try:
        conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
        try:
            cursor = conn.cursor()

            sql = "update text_table set train_test_data = %s where sentence_id = %s"
            cursor.executemany(sql, sentence_id)
            time.sleep(0.1)

            conn.commit()
            cursor.close()
        finally:
            # Release the connection even if the update fails.
            conn.close()

    except Exception as e:
        print(e)

def teacher_sampling(nums, count):
    """Return ``count`` distinct elements of ``nums`` chosen at random.

    Used to pick the training-data sentence ids of one label.

    Raises:
        ValueError: if ``count`` is negative or larger than ``len(nums)``.
    """
    picked = random.sample(nums, k=count)
    return picked


def train_test_data_reset(project, mysql_id, mysql_pass):
    """Clear previous sampling results in ``text_table``.

    Sets ``train_test_data`` to 0 and ``sentence_vector`` to NULL for every row,
    then commits both updates together.

    Args:
        project: MySQL database name.
        mysql_id: MySQL user name.
        mysql_pass: MySQL password.

    Errors are printed, not raised.
    """
    try:
        conn = MySQLdb.connect(db=project, user=mysql_id, passwd=mysql_pass, charset="utf8mb4")
        try:
            cursor = conn.cursor()

            sql = "update text_table set train_test_data = 0"
            cursor.execute(sql)
            time.sleep(0.1)

            sql = "update text_table set sentence_vector = NULL"
            cursor.execute(sql)

            conn.commit()
            cursor.close()
        finally:
            # Release the connection even if an update fails.
            conn.close()

    except Exception as e:
        print(e)

if __name__ == "__main__":
    # Alternative target database used during development: "ok_table".
    learn_sampling(project="ok", rate=90, sampling="0")
