#!/usr/bin/env python
#
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""Secure the MySQL installation.

Performs the equivalent of /usr/bin/mysql_secure_installation.
"""

import json
import logging
import os

import MySQLdb

logging.basicConfig()
logger = logging.getLogger('mysql-secure')


# TODO(therese-mchale) Refactor to use a common
# load_userfile for this and 50-mysql-users
def load_userfile(path, users):
    """Merge the user entries from a JSON file into ``users``.

    The file at ``path`` must hold a JSON list of objects, each carrying a
    ``username`` key. Every entry is stored in ``users`` under its username,
    replacing any entry already present for that name. A ``path`` that does
    not exist is silently ignored.

    :param path: location of the JSON user file
    :param users: dict updated in place, keyed by username
    :raises ValueError: if the top-level JSON value is not a list
    """
    # A missing file is not an error: there is simply nothing to merge.
    if not os.path.exists(path):
        return

    with open(path) as handle:
        entries = json.load(handle)

    if not isinstance(entries, list):
        raise ValueError('%s must be a list' % (path))

    for entry in entries:
        users[entry['username']] = entry


def secure_installation(rootuser):
    """Apply the mysql_secure_installation hardening steps to the server.

    Connects using the credentials in ``~/.my.cnf`` and then:

    * removes anonymous accounts,
    * removes root accounts for any host other than ``localhost``,
      ``127.0.0.1`` and ``::1``,
    * drops the ``test`` database and the ``mysql.db`` rows referring to it,
    * sets the password of every root account when ``rootuser`` contains a
      ``password`` key,
    * reloads the grant tables so the direct edits above take effect.

    :param rootuser: dict describing the root user; only the optional
        ``password`` key is read here
    """
    # If root password is set assumes ~/.my.cnf is configured correctly
    conn = MySQLdb.Connect(read_default_file=os.path.expanduser('~/.my.cnf'))
    try:
        cursor = conn.cursor()
        try:
            # Remove Anonymous Users. Parameters must be a sequence, so the
            # single empty user name is wrapped in a one-element tuple.
            cursor.execute("DELETE FROM mysql.user WHERE User = %s", ("",))
            # Remove remote root access
            cursor.execute("DELETE FROM mysql.user WHERE User = %s "
                           "AND Host NOT IN (%s, %s, %s)",
                           ("root", "localhost", "127.0.0.1", "::1"))
            # Remove the test database
            cursor.execute("DROP DATABASE IF EXISTS test")
            cursor.execute("DELETE FROM mysql.db WHERE Db=%s OR Db=%s",
                           ("test", "test\\_%"))
            # Make sure the root password for ALL root accounts
            # (i.e. 'localhost', '127.0.0.1', '::1' included) is set correctly
            # NOTE(review): relies on the mysql.user.Password column and the
            # PASSWORD() function; confirm the target MySQL version has them.
            if 'password' in rootuser:
                rootpwd = rootuser['password']
                cmd = ("UPDATE mysql.user SET Password=PASSWORD("
                       "%s) WHERE User=%s")
                cursor.execute(cmd, (rootpwd, "root"))
            # The statements above edit the grant tables directly, so the
            # server must reload them whether or not a password was set.
            cursor.execute("FLUSH PRIVILEGES")
        finally:
            cursor.close()
        # Commit explicitly rather than relying on the connection's context
        # manager, whose behaviour differs between MySQLdb releases.
        conn.commit()
    finally:
        # On failure the connection is closed without a commit, leaving any
        # uncommitted work to be discarded by the server.
        conn.close()

# Collect the configured database users. The static file is read first, so
# an entry in dbusers.json replaces a static entry with the same username.
users = {}
for userfile in ('/etc/mysql/static-dbusers.json', '/etc/mysql/dbusers.json'):
    load_userfile(userfile, users)

# A 'root' entry is required; its absence surfaces as a KeyError here.
rootuser = users['root']
secure_installation(rootuser)
