# -*- encoding: utf-8 -*-
#   Copyright 2008 Agile42 GmbH, Berlin (Germany)
#   Copyright 2007 Andrea Tomasini <andrea.tomasini_at_agile42.com>
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#   Author: Andrea Tomasini <andrea.tomasini_at_agile42.com>

from datetime import datetime, timedelta
from random import randint
from time import mktime

from svn import fs, repos, core
from trac.env import Environment
from trac.perm import DefaultPermissionPolicy, PermissionCache, PermissionSystem
from trac.test import EnvironmentStub, Mock
from trac.ticket import Milestone
from trac.util.datefmt import localtz, parse_date, to_timestamp, to_datetime
from trac.versioncontrol.api import RepositoryManager
from trac.web.api import Cookie, Href

from agilo.api.controller import ValueObject
from agilo.scrum import BacklogModelManager, SprintModelManager, \
    TeamModelManager, TeamMemberModelManager
from agilo.scrum.backlog.burndown import RemainingTime
from agilo.ticket.model import AgiloTicketModelManager
from agilo.utils import Key, Status, BacklogType, Type
from agilo.utils.compat import exception_to_unicode
from agilo.utils.sorting import SortOrder
from agilo.utils.config import AgiloConfig


# Public API of this test-support module
__all__ = [
    'BetterEnvironmentStub',
    'TestEnvHelper',
]

# The EnvironmentStub needs Agilo's roles to be visible in order to load them,
# so this module is imported for its side effects only.
import agilo.utils.permissions


class BetterEnvironmentStub(EnvironmentStub):
    """EnvironmentStub that honours an explicit list of enabled components.

    This is a patch similar to http://trac.edgewall.org/ticket/8591: even if
    that patch goes into trac 0.11.x, older trac versions still have to be
    supported.
    """

    def __init__(self, default_data=False, enable=None):
        """Build the stub and write the ``[components]`` configuration.

        When *enable* is given, ``trac.*`` is first marked as disabled and
        then every entry of *enable* (a component name or class) is marked
        as enabled, using normalize_configuration_key() to build the key.
        """
        super(BetterEnvironmentStub, self).__init__(default_data=default_data,
                                                    enable=enable)
        config = self.config
        if enable is not None:
            # Start from "all of trac disabled"; the loop below re-enables
            # only what the caller asked for.
            config.set('components', 'trac.*', 'disabled')
        for component in (enable or ()):
            config.set('components',
                       self.normalize_configuration_key(component),
                       'enabled')

    def normalize_configuration_key(self, name_or_class):
        """Return the lower-cased ``[components]`` key for a component.

        A string is used as-is; anything else is turned into
        ``<__module__>.<__name__>``.
        """
        if isinstance(name_or_class, basestring):
            key = name_or_class
        else:
            key = '.'.join((name_or_class.__module__, name_or_class.__name__))
        return key.lower()

    def is_component_enabled(self, cls):
        """Delegate to ``Environment.is_component_enabled`` instead of the
        EnvironmentStub override (presumably so that the ``[components]``
        settings written in __init__ are the ones that count)."""
        return Environment.is_component_enabled(self, cls)

    def get_known_users(self, db=None):
        """Return ``self.known_users``; *db* is ignored.

        Works around trac ticket http://trac.edgewall.org/ticket/7619, fixed
        in 0.11.2; tests still run against 0.11.1 because it is still the
        default for some Linux distributions.
        """
        return self.known_users




class TestEnvHelper(object):
    """Helper to create an environment give the path. Used for testing"""
    def __init__(self, env=None, strict=False, enable=()):
        if env is None:
            self.env = BetterEnvironmentStub(default_data=True, 
                                             enable=['trac.*', 'agilo.*'] + list(enable))
            # Set the connection type in the config as memory
            # Sent patch to trac #7208 waiting for commit
            #self.env.config.set('trac', 'database', 'sqlite::memory:')
            self.env.config.set('trac', 'permission_policies', 'AgiloPolicy, DefaultPermissionPolicy, LegacyAttachmentPolicy')
            if strict:
                self.env.config.set('ticket', 'restrict_owner', 'true')
        else:
            self.env = env
        self.env_path = self.env.path
        self.objects = list()
        self.files = list()
        self._ticket_counter = 0
        # Initialize the config and the database
        # Avoid recursive imports - TestEnvHelper must not trigger anything
        from agilo.init import AgiloInit
        ai = AgiloInit(self.env)
        db = self.env.get_db_cnx()
        if ai.environment_needs_upgrade(db):
            ai.upgrade_environment(db)
            
        # commit changes to the db
        db.commit()
        self.svn_repos = None
        try:
            repo_path = RepositoryManager(self.env).repository_dir
            self.svn_repos = repos.open(repo_path)
        except:
            #No repo configured
            pass
        
    def get_env(self):
        """Returns the created environment"""
        return self.env
        
    def get_env_path(self):
        """Returns the trac environment path"""
        return self.env_path
    
    def _set_sprint_date_normalization(self, enabled):
        config = AgiloConfig(self.env)
        config.change_option('sprints_can_start_or_end_on_weekends', not enabled,
                             section=AgiloConfig.AGILO_GENERAL)
        config.save()
        config.reload()
    
    def enable_sprint_date_normalization(self):
        self._set_sprint_date_normalization(True)
        assert AgiloConfig(self.env).sprints_can_start_or_end_on_weekends == False
    
    def disable_sprint_date_normalization(self):
        self._set_sprint_date_normalization(False)
        assert AgiloConfig(self.env).sprints_can_start_or_end_on_weekends == True
    
    def create_milestone(self, name, due=None, duration=20):
        """
        Creates a milestone with the given name and due
        date, the latter should be a datetime object
        """
        conn = self.env.get_db_cnx()
        m = Milestone(self.env, db=conn)
        m.name = name
        if due is not None and isinstance(due, datetime):
            dueo = due.toordinal() + duration
            m.due = mktime(datetime.fromordinal(dueo).timetuple())
        try:
            m.insert()
            conn.commit()
        except:
            conn.rollback()
            # The milestone already exists fetch it
            m = Milestone(self.env, name=name, db=conn)
        return m
    
    def delete_milestone(self, name):
        """Deletes the given milestone"""
        conn = self.env.get_db_cnx()
        m = Milestone(self.env, name=name, db=conn)
        m.delete(db=conn)
        
    def list_milestone_names(self):
        """Returns a list of all the milestone names in the env"""
        conn = self.env.get_db_cnx()
        names = []
        sql_query = "SELECT name FROM milestone"
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            rows = cursor.fetchall()
            for name, in rows:
                names.append(name)
        except:
            conn.rollback()
        return names
    
    def delete_all_milestones(self):
        """Delete all the milestones in the environment"""
        names = self.list_milestone_names()
        for name in names:
            self.delete_milestone(name)
    
    def generate_remaining_time_data(self, task, start_date, end_date=None, initial=12):
        """Generates remaining time data for the given task starting
        at the start_date, and ending at the end_date or today"""
        if task.get_type() == Type.TASK:
            if not end_date:
                end_date = parse_date('now')
            
            last = initial
            if task.has_owner:
                # set it to accepted
                task[Key.STATUS] = Status.ACCEPTED
                rt = RemainingTime(self.env, task)
                day = start_date
                one_day = timedelta(days=1)
                while day <= end_date:
                    last -= randint(0,4)
                    if last < 0:
                        last = 0
                    rt.set_remaining_time(last, day)
                    day += one_day
            # Sets the last into the task as well
            task[Key.REMAINING_TIME] = str(last)
            if last == 0:
                task[Key.STATUS] = Status.CLOSED
            task.save_changes(author=task[Key.OWNER], comment='Updated time...')
    
    def create_sprint(self, name, start=None, end=None, duration=20, 
                      milestone=None, team=None):
        """Creates a Sprint for the given milestone, if doesn't exists, first
        it creates a Milestone"""
        # We need UTC shifted timezone
        now = parse_date('now')
        # If the start day is set to today, the sprint will 
        # normalize it to 9:00am of the start day and all the tests
        # will fail, till 9:00am in the morning...
        if start is None:
            # we set hours to 0 so will be normalized to 9am at any
            # time of the day, when running tests.
            start = (now - timedelta(days=3)).replace(hour=0)
        if milestone is None:
            milestone = self.create_milestone('Milestone for %s' % name)
        # It should automatically load the existing Sprint if already there
        if isinstance(milestone, Milestone):
            milestone = milestone.name
        
        s = SprintModelManager(self.env).create(name=name, 
                                                start=start,
                                                end=end,
                                                duration=duration,
                                                milestone=milestone)
        assert s != None
        if team is not None:
            if isinstance(team, basestring):
                team = self.create_team(name=team)
            s.team = team
        s.save()
        return s
    
    def delete_sprint(self, name):
        """Deletes the given Sprint from the environment"""
        smm = SprintModelManager(self.env)
        s = smm.get(name=name)
        smm.delete(s)
        
    def list_sprint_names(self):
        """Returns a list of all the sprint names"""
        smm = SprintModelManager(self.env)
        return [s.name for s in smm.select()]
    
    def delete_all_sprints(self):
        """Deletes all the sprints in the environment"""
        names = self.list_sprint_names()
        for name in names:
            self.delete_sprint(name)
    
    def create_team(self, name='Team'):
        """Creates and return a team object, if already existing just returns it"""
        tmm = TeamModelManager(self.env)
        team = tmm.get(name=name)
        if not team:
            team = tmm.create(name=name)
        return team
    
    def list_team_names(self):
        """Returns a list of existing team names"""
        tmm = TeamModelManager(self.env)
        return [t.name for t in tmm.select()]
    
    def create_member(self, name, team=None):
        """Creates a team member for the given team with the given name"""
        if team is not None and isinstance(team, basestring):
            team = self.create_team(team)
        tmmm = TeamMemberModelManager(self.env)
        member = tmmm.get(name=name, team=team)
        if not member:
            member = tmmm.create(name=name, team=team)
        return member
    
    def create_backlog(self, name='Performance Backlog', 
                       num_of_items=10, b_type=BacklogType.GLOBAL, 
                       ticket_types=[Type.REQUIREMENT, Type.USER_STORY], 
                       sorting_keys=[(Key.BUSINESS_VALUE, SortOrder.DESCENDING)],
                       scope=None):
        """Creates a Backlog with the given parameters and returns it"""
        # Characteristic properties
        ticket_custom = AgiloConfig(self.env).get_section(AgiloConfig.TICKET_CUSTOM)
        char_props = {Type.REQUIREMENT: [(Key.BUSINESS_VALUE, 
                                          ticket_custom.get("%s.options" % Key.BUSINESS_VALUE).split('|'))],
                      Type.USER_STORY: [(Key.STORY_PRIORITY, 
                                         ticket_custom.get("%s.options" % Key.STORY_PRIORITY).split('|')),
                                        (Key.STORY_POINTS,
                                         ticket_custom.get("%s.options" % Key.STORY_POINTS).split('|'))],
                      Type.TASK: [(Key.REMAINING_TIME, ['12', '8', '4', '6', '2', '0'])],
                      Type.BUG: [(Key.PRIORITY, ['minor', 'major', 'critical'])]}
        # creates the specified number of tickets
        last = None
        for i in range(num_of_items):
            t_type = ticket_types[randint(0, len(ticket_types) - 1)]
            t_props = dict([(prop_name, values[randint(0, len(values) - 1)]) for \
                            prop_name, values in char_props[t_type]])
            if scope and BacklogType.LABELS.get(b_type) in \
                    AgiloConfig(self.env).TYPES.get(t_type):
                # Set the scope to the ticket
                t_props[BacklogType.LABELS.get(b_type)] = scope
            t_props[Key.SUMMARY] = 'Agilo Ticket #%d' % i
            
            actual = self.create_ticket(t_type, props=t_props)
#            print "Ticket(%s): %s => %s (Backlog: %s)" % \
#                   (actual[Key.STATUS], actual, t_props, 
#                    BacklogType.LABELS.get(b_type))
            if last:
                if ticket_types.index(last.get_type()) > ticket_types.index(actual.get_type()):
                    assert actual.link_to(last)
                else:
                    assert last.link_to(actual)
                last = actual
        # Creates the backlog
        bmm = BacklogModelManager(self.env)
        b = bmm.create(name=name,
                       ticket_types=ticket_types,
                       sorting_keys=sorting_keys,
                       b_type=b_type,
                       scope=scope)
        # We need to reload with scope if a scoped backlog
        if b_type != BacklogType.GLOBAL:
            b = bmm.get(name=name, scope=scope, reload=True)
            b.save()
            b.reload()
        return b
        
    def create_file(self, file_name, content, author, comment):
        """
        Creates a file in the SVN repository with the given
        name and content (text). Returns the committed revision
        """
        assert self.svn_repos is not None, "SVN repository not set..."
        # Get an SVN file system pointer
        fs_ptr = repos.fs(self.svn_repos)
        rev = fs.youngest_rev(fs_ptr)
        # Create and SVN transaction
        txn = fs.begin_txn(fs_ptr, rev)
        txn_root = fs.txn_root(txn)
        # Create a file in the root transaction
        fs.make_file(txn_root, file_name)
        stream = fs.apply_text(txn_root, file_name, None)
        core.svn_stream_write(stream, "%s\n" % content)
        core.svn_stream_close(stream)
        # Now set the properties svn:log and svn:author to
        # the newly created node (file)
        fs.change_txn_prop(txn, 'svn:author', author)
        fs.change_txn_prop(txn, 'svn:log', comment)
        # Commit the transaction
        fs.commit_txn(txn)
        # Add teh file to the list of created files
        self.files.append(file_name)
        # Returns therevision number
        return rev + 1
        
    def delete_file(self, file_name):
        """Deletes the given file from the repository"""
        assert self.svn_repos is not None, "SVN repository not set..."
        # Get an SVN file system pointer
        fs_ptr = repos.fs(self.svn_repos)
        rev = fs.youngest_rev(fs_ptr)
        # Create and SVN transaction
        txn = fs.begin_txn(fs_ptr, rev)
        txn_root = fs.txn_root(txn)
        # Create a file in the root transaction
        fs.delete(txn_root, file_name)
        # Commit the transaction
        fs.commit_txn(txn)
        
    def retrieve_file(self, file_name, rev=None):
        """
        Retrieves the given file name, at the specified revision or 
        the latest available from the SVN repository
        """
        assert self.svn_repos is not None, "SVN repository not set..."
        # Get an SVN file system pointer
        fs_ptr = repos.fs(self.svn_repos)
        if rev is None:
            rev = fs.youngest_rev(fs_ptr)
        root = fs.revision_root(fs_ptr, rev)
        stream = fs.file_contents(root, file_name)
        svn_file = core.Stream(stream)
        core.svn_stream_close(stream)
        return svn_file
        
    def create_ticket(self, t_type, props=None):
        """Utility to create a ticket of the given type"""
        self._ticket_counter += 1
        
        t = dict()
        t['t_type'] = t_type
        t[Key.SUMMARY] = u'%s n.%s' % (t_type.title(), self._ticket_counter)
        t[Key.DESCRIPTION] = u'Description for %s' % t_type
        # Not mandatory, but if not set the status is None and all the query
        # with status != 'closed' will fail.
        t[Key.STATUS] = Status.NEW
        if props and isinstance(props, dict):
            for key, value in props.items():
                t[key] = value
                #print "Key:value:", repr(key), value
        ticket = AgiloTicketModelManager(self.env).create(**t)
        assert ticket.id is not None, "Ticket creation failed!"
        assert ticket.exists
        self.objects.append(ticket)
        return ticket
    
    def delete_ticket(self, t_id):
        """Deletes the ticket with the given ticket id"""
        atm = AgiloTicketModelManager(self.env)
        ticket = atm.get(tkt_id=t_id)
        try:
            atm.delete(ticket)
        except Exception, e:
            print exception_to_unicode(e)
        
    def delete_all_tickets(self):
        """Delete all the tickets in the environment"""
        atm = AgiloTicketModelManager(self.env)
        tickets = atm.select()
        for t in tickets:
            try:
                atm.delete(t)
            except Exception, e:
                print exception_to_unicode(e)
        
    def load_ticket(self, ticket=None, t_id=None):
        """
        Utility method to load a ticket from trac. Used to check
        committed changes
        """
        assert ticket or t_id, "Supply either a ticket or and id"
        if ticket:
            t_id = ticket.id
        tm = AgiloTicketModelManager(self.env)
        tm.get_cache().invalidate(key_instance=((t_id,), None))
        t = tm.get(tkt_id=t_id)
        return t

    def delete_created_tickets(self):
        """Deletes all the tickets created by the helper"""
        for obj in self.objects:
            obj.delete()
            self._ticket_counter -= 1
    
    def delete_files(self):
        """Deletes all the files created by the helper"""
        if self.svn_repos is not None:
            # Get an SVN file system pointer
            fs_ptr = repos.fs(self.svn_repos)
            rev = fs.youngest_rev(fs_ptr)
            # Create and SVN transaction
            txn = fs.begin_txn(fs_ptr, rev)
            txn_root = fs.txn_root(txn)
            # Create a file in the root transaction
            for svn_file in self.files:
                fs.delete(txn_root, svn_file)
            # Commit the transaction
            fs.commit_txn(txn)
        
    def cleanup(self):
        """Delete all the tickets and all the file created"""
        self.delete_files()
        self.delete_created_tickets()
        self.enable_sprint_date_normalization()
    
    def purge_ticket_history(self, ticket):
        db = self.env.get_db_cnx()
        cursor = db.cursor()
        ticket_id = getattr(ticket, 'id', None)
        if ticket_id == None:
            ticket_id = int(ticket)
        sql = 'DELETE FROM ticket_change WHERE ticket=%d'
        cursor.execute(sql % ticket_id)
    
    def emulate_login(self, username, when=None):
        """Emulates a login for the given username, by setting an entry in the
        session table, if when is specified will be also set the datetime of 
        the login to when, otherwise to now"""
        if when is None:
            when = to_datetime(None)
        db = self.env.get_db_cnx()
        try:
            cursor = db.cursor()
            cursor.execute("SELECT sid FROM session WHERE sid='%s'" % username)
            last_visit = to_timestamp(when)
            if cursor.fetchone():
                cursor.execute("UPDATE session SET last_visit=%s WHERE sid='%s'" % \
                               (last_visit, username))
            else:
                cursor.execute("INSERT INTO session (sid, authenticated, last_visit) " \
                               "VALUES ('%s', 1, %s)" % (username, last_visit))
            db.commit()
        except Exception, e:
            db.rollback()
            assert False, "Unable to complete login for user: %s (%)" % \
                (username, exception_to_unicode(e))
            
    def move_changetime_to_the_past(self, tickets):
        """Trac has some tables which use the time as primary key but only with 
        a 'second' precision so you can't save a ticket twice in a second. This
        is very annoying for unit tests so this method can just reset the times
        for a the given tickets to the past."""
        # Please note that while this works great for the same connection pool
        # (e.g. for unit tests) the method did not work for me in a functional 
        # test (svn_hooks_test). I found that trac always read the old data
        # no matter what I did to put them into the database...
        db = self.env.get_db_cnx()
        cursor = db.cursor()
        for ticket in tickets:
            # This should work for ticket ids too so that it is easier to use 
            # this method from functional tests, too.
            ticket_id = getattr(ticket, 'id', None)
            if ticket_id == None:
                ticket_id = int(ticket)
            sql = 'UPDATE ticket_change SET time=time - 2 WHERE ticket=%d'
            cursor.execute(sql % ticket_id)
    
    def mock_request(self, username='anonymous', **kwargs):
        response = ValueObject(headers=dict(), body='')
        
        perm = PermissionCache(self.env, username)
        attributes = dict(args=dict(), tz=localtz, perm=perm, method='GET',
                          path_info='/')
        attributes.update(kwargs)
        return Mock(authname=username, base_path=None, href=Href('/'),
                    chrome=dict(warnings=[], notices=[]),
                    incookie=Cookie(), outcookie=Cookie(), 
                    response=response,
                    
                    send_response=lambda code: response.update({'code': code}),
                    send_header=lambda name, value: response['headers'].update({name: value}),
                    end_headers=lambda: None,
                    write=lambda string: ''.join((response['body'], string)),
                    
                    **attributes)
    
    def grant_permission(self, username, action):
        # DefaultPermissionPolicy will cache permissions for 5 seconds so we 
        # need to reset the cache
        DefaultPermissionPolicy(self.env).permission_cache = {}
        PermissionSystem(self.env).grant_permission(username, action)
        assert action in PermissionSystem(self.env).get_user_permissions(username)

