import pickle

from trial import Trial


class Mnd:
    """Exchange trial data through a pickled file.

    ``export_data`` reads a trial from a database object and writes it to
    ``self.filename``.  ``data_file_valid`` loads such a file and caches its
    fields on the instance.  ``get_detail`` and ``import_data`` then use those
    cached fields.
    """

    # Sample size used when the trial has at most this many subjects.
    MIN_SAMPLE_SIZE = 30
    # Extra free slots added on top of the subject count for larger trials.
    SAMPLE_SIZE_MARGIN = 10

    def __init__(self, filename):
        """Remember the path of the data file; nothing is read or written yet."""
        self.filename = filename
        # Last dictionary written by export_data / read by data_file_valid.
        self.data = None

    def export_data(self, trial_id, db):
        """Serialize trial ``trial_id`` from ``db`` into ``self.filename``.

        The assembled dictionary is also kept in ``self.data``.

        Args:
            trial_id: Trial identifier forwarded to every ``db`` query.
            db: Data-access object providing ``get_trial_setting``,
                ``get_subjects``, ``read_treatments``, ``read_factors``,
                ``factor_levels``, ``get_subject_levels``, ``has_preload``
                and ``get_preload``.

        Returns:
            True once the file has been written.
        """
        trial = Trial(db.get_trial_setting(trial_id))
        subjects = db.get_subjects(trial_id)
        sample_size, ui_pool = self._build_ui_pool(subjects)
        treatments = db.read_treatments(trial_id)
        groups = [{'name': treatment[1], 'allocation_ratio': treatment[2]}
                  for treatment in treatments]
        factors = db.read_factors(trial_id)
        # Must run before the allocation and preload steps: it appends each
        # factor's level ids to the factor row, which they read as factor[-1].
        variables = self._build_variables(factors, db)
        allocations = self._build_allocations(subjects, treatments, factors, db)
        initial_freq_table = self._build_initial_freq_table(
            trial_id, treatments, factors, db)
        self.data = {'trial_title': trial.title,
                     # NOTE(review): the title is stored as the description;
                     # confirm whether Trial exposes a separate description.
                     'trial_description': trial.title,
                     'trial_properties': [],
                     'high_prob': trial.base_prob,
                     'prob_method': trial.prob_method,
                     'distance_measure': trial.dist_method,
                     'ui_pool': ui_pool,
                     'sample_size': sample_size,
                     'groups': groups,
                     'variables': variables,
                     'allocations': allocations,
                     'initial_freq_table': initial_freq_table}
        # ``with`` guarantees the file is closed even if pickling fails.
        # Protocol 2 is presumably kept for compatibility with older
        # readers of this format -- TODO confirm before changing it.
        with open(self.filename, 'wb') as fp:
            pickle.dump(self.data, fp, protocol=2)
        return True

    def _build_ui_pool(self, subjects):
        """Return ``(sample_size, ui_pool)`` for an export.

        ``sample_size`` is MIN_SAMPLE_SIZE, or ``len(subjects) +
        SAMPLE_SIZE_MARGIN`` when there are more subjects than that.
        ``ui_pool`` holds the identifiers in ``range(sample_size)`` not
        already used by a subject (``subject[2]``), in ascending order and
        truncated to the number of still-unallocated slots.
        """
        if len(subjects) > self.MIN_SAMPLE_SIZE:
            sample_size = len(subjects) + self.SAMPLE_SIZE_MARGIN
        else:
            sample_size = self.MIN_SAMPLE_SIZE
        taken = {subject[2] for subject in subjects}
        ui_pool = [ui for ui in range(sample_size) if ui not in taken]
        # Keep only as many free identifiers as there are open slots.
        del ui_pool[max(0, sample_size - len(subjects)):]
        return sample_size, ui_pool

    def _build_variables(self, factors, db):
        """Return one variable dict per factor and annotate each factor row.

        Side effect: the list of the factor's level ids is appended to every
        row in ``factors`` so later steps can map level ids to positions.
        """
        variables = []
        for factor in factors:
            factor_levels = db.factor_levels(factor[0])
            titles = [factor_level[1] for factor_level in factor_levels]
            factor.append([factor_level[0] for factor_level in factor_levels])
            variables.append({'name': factor[1],
                              'weight': factor[2],
                              'levels': ','.join(titles)})
        return variables

    def _build_allocations(self, subjects, treatments, factors, db):
        """Return, per subject, its UI, treatment position and level positions."""
        # Position of each treatment id in ``treatments`` (first one wins).
        index_by_treatment = {}
        for treatment_index, treatment in enumerate(treatments):
            index_by_treatment.setdefault(treatment[0], treatment_index)
        allocations = []
        for subject in subjects:
            allocation = {'UI': subject[2]}
            # NOTE(review): a subject whose treatment id is not in
            # ``treatments`` gets no 'allocation' key, which makes
            # import_data fail with KeyError for that file.
            if subject[1] in index_by_treatment:
                allocation['allocation'] = index_by_treatment[subject[1]]
            allocation['levels'] = [
                self.get_factor_level_index(factors, subject_level)
                for subject_level in db.get_subject_levels(subject[0])]
            allocations.append(allocation)
        return allocations

    def _build_initial_freq_table(self, trial_id, treatments, factors, db):
        """Return preload counts as ``table[treatment][factor][level]``.

        Returns 0 (not an empty list) when the trial has no preload; readers
        of the file only test it for truthiness.
        """
        if not db.has_preload(trial_id):
            return 0
        preload = db.get_preload(trial_id)
        return [[[preload[(treatment[0], factor[0], level_id)]
                  for level_id in factor[-1]]
                 for factor in factors]
                for treatment in treatments]

    def get_factor_level_index(self, factors, subject_level):
        """Return the position of a subject's level within its factor.

        Args:
            factors: Factor rows whose first item is the factor id and whose
                last item is the list of that factor's level ids.
            subject_level: Indexable pair, presumably ``(factor_id,
                level_id)`` -- only items 0 and 1 are read.

        Returns:
            The index of ``subject_level[1]`` in the matching factor's level
            id list, or None when no factor/level matches.
        """
        for factor in factors:
            if factor[0] != subject_level[0]:
                continue
            for level_index, level_id in enumerate(factor[-1]):
                if level_id == subject_level[1]:
                    return level_index
        return None

    def data_file_valid(self):
        """Load ``self.filename`` and cache its fields on the instance.

        Returns:
            True when the file unpickles and contains every required key;
            False on any failure (missing file, corrupt pickle, missing key).
        """
        try:
            # SECURITY: pickle.load can run arbitrary code embedded in the
            # file -- only open data files that come from a trusted source.
            with open(self.filename, 'rb') as fp:
                self.data = pickle.load(fp)
            self.trial_title = self.data['trial_title']
            self.allocations = self.data['allocations']
            self.high_prob = self.data['high_prob']
            self.initial_freq_table = self.data['initial_freq_table']
            self.prob_method = self.data['prob_method']
            self.distance_measure = self.data['distance_measure']
            self.groups = self.data['groups']
            self.variables = self.data['variables']
            return True
        except Exception:
            # Validation boundary: any load/format error means "not valid".
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            return False

    def get_detail(self):
        """Return a multi-line human-readable summary of the loaded file.

        Requires a prior successful ``data_file_valid`` call.
        """
        treatments = ', '.join(group['name'] for group in self.groups)
        factors = ', '.join('{}({})'.format(variable['name'], variable['levels'])
                            for variable in self.variables)
        det = ['trial_title: {}'.format(self.trial_title),
               'Treatments: ({})'.format(treatments),
               'Factors: [{}]'.format(factors),
               '{} Subjects'.format(len(self.allocations))]
        if self.initial_freq_table:
            det.append('Trial has preload')
        return '\n'.join(det)

    def import_data(self, db):
        """Insert the loaded trial into ``db`` and return the new trial id.

        Requires a prior successful ``data_file_valid`` call.  Subjects keep
        their file UIs only when every UI is numeric; otherwise they are
        numbered 0, 1, 2, ... in file order.

        Args:
            db: Data-access object providing ``insert_trial``,
                ``insert_treatment``, ``insert_factor``, ``insert_level``,
                ``insert_subject``, ``insert_subject_levels`` and
                ``save_preload``.

        Returns:
            The id returned by ``db.insert_trial``.
        """
        trial_id = db.insert_trial(self.trial_title)
        treatments = [db.insert_treatment(trial_id, group['name'])
                      for group in self.groups]
        # factors[i] == [factor_id, [level_id, ...]] in file order.
        factors = []
        for variable in self.variables:
            factor_id = db.insert_factor(trial_id, variable['name'])
            level_ids = [db.insert_level(trial_id, factor_id, level)
                         for level in variable['levels'].split(',')]
            factors.append([factor_id, level_ids])
        all_numeric = all(str(allocation['UI']).isnumeric()
                          for allocation in self.allocations)
        for position, allocation in enumerate(self.allocations):
            treatment_id = treatments[allocation['allocation']]
            identifier_value = allocation['UI'] if all_numeric else position
            subject_id = db.insert_subject(identifier_value, treatment_id, trial_id)
            subject_levels = [(factors[factor_index][0],
                               factors[factor_index][1][level_index])
                              for factor_index, level_index
                              in enumerate(allocation['levels'])]
            db.insert_subject_levels(trial_id, subject_id, subject_levels)
        if not self.initial_freq_table:
            return trial_id
        # Translate positional counts back to (treatment, factor, level) ids.
        preload = {}
        for treatment_index, row in enumerate(self.initial_freq_table):
            t = treatments[treatment_index]
            for factor_index, counts in enumerate(row):
                f = factors[factor_index][0]
                for level_index, count in enumerate(counts):
                    lv = factors[factor_index][1][level_index]
                    preload[(t, f, lv)] = count
        db.save_preload(trial_id, preload)
        return trial_id