#!/usr/bin/env python
#
# Copyright notice:
# Copyright CERN, 2016.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os
import sys
import traceback
from datetime import datetime
from optparse import OptionParser
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError

from fts3.util.config import fts3_config_load

log = logging.getLogger(__name__)


class BaseRotate(object):
    """
    Common logic for rotations.

    Subclasses supply the backend specific SQL by overriding _get_file_entries,
    _create_as and _bulk_rename (and, optionally, _lock and _unlock).
    Calling the instance checks the size of t_file_backup and rotates the
    archive tables when it exceeds opts.max_size.
    """

    def __init__(self, opts):
        """
        Constructor

        :param opts: Parsed options. db_url is used to create the engine;
                     dry_run and max_size are read later on.
        """
        self.engine = create_engine(opts.db_url)
        self.opts = opts

    def _run_query(self, always, query, *args):
        """
        Run the query if not running with dry-run. Run it anyway if always is true.

        :param always: If true, the query is executed even in dry-run mode
        :param query:  The statement to execute
        :param args:   Extra positional arguments forwarded to engine.execute
        :return: The result of engine.execute, or None if the query was skipped
        """
        if self.opts.dry_run and not always:
            # Interpolate only when there is something to interpolate, so a
            # statement with a literal '%' and no arguments can still be logged
            rendered = query % args if args else query
            log.debug(rendered + ' (DRY)')
            return None
        log.debug(query)
        return self.engine.execute(query, *args)

    def _lock(self):
        """
        Lock the tables so no one uses them
        """
        log.warning('Lock not implemented')

    def _unlock(self):
        """
        Unlock the tables
        """
        log.warning('Unlock not implemented')

    def _get_file_entries(self):
        """
        Return how many entries there are in t_file_backup
        """
        raise NotImplementedError('Not implemented: _get_file_entries')

    def _create_as(self, table, like):
        """
        Create placeholder table

        :param table: Name of the table to create
        :param like:  Name of the table whose structure is copied
        """
        raise NotImplementedError('Not implemented: _create_as')

    def _bulk_rename(self, tables):
        """
        Bulk table renaming

        :param tables: List of (current name, new name) pairs
        """
        raise NotImplementedError('Not implemented: _bulk_rename')

    def _rotate_tables(self):
        """
        Rename all archive tables and create new ones
        """
        # The placeholders are created first, so the renaming below only has
        # to swap names
        for table in ['t_job', 't_file', 't_dm']:
            log.info('Creating placeholder for %s_backup' % table)
            self._create_as('%s_backup_future' % table, table)

        suffix = datetime.utcnow().strftime("%Y_%m_%d")

        log.info('Bulk rename using suffix _%s' % suffix)
        # Order matters: the current archives must be moved out of the way
        # before the placeholders can take their names
        self._bulk_rename(
            [
                ('t_job_backup', 't_job_backup_%s' % suffix),
                ('t_file_backup', 't_file_backup_%s' % suffix),
                ('t_dm_backup', 't_dm_backup_%s' % suffix),
                ('t_job_backup_future', 't_job_backup'),
                ('t_file_backup_future', 't_file_backup'),
                ('t_dm_backup_future', 't_dm_backup')
            ]
        )

    def __call__(self):
        """
        Rotate the tables if t_file_backup holds more than opts.max_size entries
        """
        # If the lock can not be taken there is nothing to release, so the
        # call stays outside of the try block
        self._lock()
        try:
            n_archive_entries = self._get_file_entries()
            log.info('%d entries in t_file_backup' % n_archive_entries)
            if n_archive_entries > self.opts.max_size:
                log.info('Limit reached. Rotating tables!')
                self._rotate_tables()
            else:
                log.info('Limit not reached. Not rotating tables.')
        finally:
            # Release the lock even if the check or the rotation failed
            self._unlock()


class MySqlRotate(BaseRotate):
    """
    Rotate MySQL/MariaDB tables
    """

    def _lock(self):
        """
        Take the named lock 'fts_archive_rotate' without waiting.
        Raises an Exception if GET_LOCK does not return 1.
        """
        row = self._run_query(True, "SELECT GET_LOCK('fts_archive_rotate', 0)").fetchone()
        if row[0] != 1:
            raise Exception('Could not acquire lock. Probably another fts_archive_rotate is running.')

    def _unlock(self):
        """
        Release the named lock 'fts_archive_rotate'
        """
        self._run_query(True, "SELECT RELEASE_LOCK('fts_archive_rotate')")

    def _get_file_entries(self):
        """
        Return the row count that information_schema reports for t_file_backup
        in the current schema.
        Raises an Exception if the table is not listed.
        """
        # NOTE(review): table_rows is documented as an estimate for some
        # engines (e.g. InnoDB), so the limit is presumably approximate
        row = self._run_query(
            True,
            "SELECT table_rows "
            "FROM information_schema.tables "
            "WHERE table_name='t_file_backup' and table_schema=DATABASE()"
        ).fetchone()
        if row is None:
            raise Exception('Table t_file_backup not found in the current schema')
        # Positional access does not depend on the case the server uses for
        # the column name. A NULL count is treated as an empty table.
        return int(row[0] or 0)

    def _create_as(self, table, like):
        """
        Create an empty table with the columns of another one, using the
        storage engine given by opts.engine. Does nothing if it already exists.

        :param table: Name of the table to create
        :param like:  Name of the table whose columns are copied
        """
        # WHERE NULL never matches, so only the structure is copied
        self._run_query(
            False,
            "CREATE TABLE IF NOT EXISTS %s ENGINE=%s (SELECT * FROM %s WHERE NULL)" % (table, self.opts.engine, like)
        )

    def _bulk_rename(self, tables):
        """
        Rename all the tables with a single RENAME TABLE statement

        :param tables: List of (current name, new name) pairs
        """
        renames = ', '.join("%s TO %s" % (old, new) for old, new in tables)
        self._run_query(False, "RENAME TABLE " + renames)


class OracleRotate(BaseRotate):
    """
    Rotate Oracle tables
    """

    def _get_file_entries(self):
        """
        Return the number of rows in T_FILE_BACKUP
        """
        row = self._run_query(
            True,
            "SELECT COUNT(*) "
            "FROM T_FILE_BACKUP"
        ).fetchone()
        return row[0]

    def _create_as(self, table, like):
        """
        Create an empty table with the columns of another one.
        If the name is already taken (ORA-00955), only a warning is logged.

        :param table: Name of the table to create
        :param like:  Name of the table whose columns are copied
        """
        try:
            # 0=1 never matches, so only the structure is copied
            self._run_query(
                False,
                "CREATE TABLE %s AS (SELECT * FROM %s WHERE 0=1)" % (table, like)
            )
        except DatabaseError as ex:
            # Already exists
            if 'ORA-00955' not in str(ex):
                raise
            log.warning('Table %s already exists!' % table)

    def _bulk_rename(self, tables):
        """
        Rename the tables one statement at a time, in the order given

        :param tables: List of (current name, new name) pairs
        """
        for old_name, new_name in tables:
            self._run_query(False, "RENAME %s TO %s" % (old_name, new_name))


def setup_logging(opts, config):
    """
    Setup the logging level and format

    A handler is added to the root logger. It writes to stderr if opts.stderr
    is set, and otherwise appends to fts3server.log inside the directory
    configured as fts3.ServerLogDirectory (default /var/log/fts3).

    The level comes from fts3.LogLevel (default INFO). TRACE is mapped to
    DEBUG, and a name the logging module does not know falls back to INFO.

    :param opts:   Parsed options, only stderr is read
    :param config: Configuration, only its get(key, default) method is used
    """
    root_logger = logging.getLogger()

    if opts.stderr:
        log_handler = logging.StreamHandler(sys.stderr)
    else:
        log_dir = config.get('fts3.ServerLogDirectory', '/var/log/fts3')
        # FileHandler owns the file, so it gets closed together with the
        # handler instead of being left open
        log_handler = logging.FileHandler(os.path.join(log_dir, 'fts3server.log'), mode='a')
    log_handler.setFormatter(
        logging.Formatter(
            fmt='%(levelname)-7s %(asctime)s; %(message)s',
            datefmt='%a %b %d %H:%M:%S %Y'
        )
    )
    root_logger.addHandler(log_handler)

    level_name = str(config.get('fts3.LogLevel', 'INFO')).upper()
    if level_name == 'TRACE':
        # The logging module has no TRACE level
        level_name = 'DEBUG'
    # getLevelName gives back a number only for the names it knows
    log_level = logging.getLevelName(level_name)
    unknown_level = not isinstance(log_level, int)
    if unknown_level:
        log_level = logging.INFO
    log_handler.setLevel(log_level)
    root_logger.setLevel(log_level)
    if unknown_level:
        root_logger.warning('Unknown log level %s, using INFO' % level_name)


def table_rotate(opts):
    """
    Rotate archive tables

    The backend is chosen from the scheme of opts.db_url. A driver suffix
    (as in dialect+driver://) is ignored.

    :param opts: Parsed options, handed over to the rotator
    :raises Exception: If the backend is neither oracle nor mysql/mariadb
    """
    scheme = opts.db_url.split(':', 1)[0]
    # Keep only the dialect, so mysql+pymysql picks the same rotator as mysql
    db_type = scheme.split('+', 1)[0]
    if db_type == 'oracle':
        rotate = OracleRotate(opts)
    elif db_type in ('mysql', 'mariadb'):
        rotate = MySqlRotate(opts)
    else:
        raise Exception('Unknown database backend %s' % db_type)
    log.info('Rotating using backend %s' % db_type)
    rotate()


if __name__ == '__main__':
    # Command line entry point: parse the options, load the FTS3
    # configuration, configure the logging and run the rotation
    parser = OptionParser(description='Rotate FTS3 archive tables')
    parser.add_option('-f', '--config-file', type=str, default='/etc/fts3/fts3config', help='FTS3 config file')
    parser.add_option(
        '--max-size', type=int, default=14000000,
        help='Rotate the tables when the number of transfers archived reach this number of entries'
    )
    parser.add_option('--engine', type=str, default='INNODB', help='MySQL engine to use for the archival tables')
    parser.add_option(
        '--db-url', type=str, default=None, help='Database URL. Overrides the configured in FTS3 config file'
    )
    parser.add_option(
        '--dry-run', default=False, action='store_true', help='Run preparation steps, log queries, but do not run them'
    )
    parser.add_option('--stderr', default=False, action='store_true', help='Log to stderr')

    opts, args = parser.parse_args()
    config = fts3_config_load(opts.config_file)

    # The command line takes precedence over the configuration file
    if opts.db_url is None:
        opts.db_url = config['sqlalchemy.url']

    setup_logging(opts, config)

    try:
        table_rotate(opts)
    except Exception as e:
        log.critical('The archival failed: %s' % str(e))
        # format_exc reads the exception being handled, so the deprecated
        # sys.exc_* globals are not needed
        log.debug(traceback.format_exc())
        sys.exit(-1)
