trove/reddwarf/extensions/mysql/common.py
Steve Leon 000744d3d6 Use database name validation only on listing and loading of databases
* We don't need to validate the database names in the
  response of a 'list_databases' call, since they may have been
  created via the root_enabled backdoor

fixes bug 1178089

Change-Id: I2f74e63cfd8b78feec9c38b5fec75138245a7f64
2013-05-30 18:04:51 -07:00

68 lines
2.3 KiB
Python

# Copyright 2012 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from reddwarf.common import exception
from reddwarf.guestagent.db import models as guest_models
from urllib import unquote
def populate_validated_databases(dbs):
    """
    Build serialized database definitions from user-provided request data.

    Each entry of ``dbs`` is read via ``.get`` for its 'name',
    'character_set' and 'collate' keys (each defaulting to '') and the
    values are assigned onto a fresh guest_models.ValidatedMySQLDatabase,
    whose serialize() result is collected.

    :param dbs: iterable of dict-like database descriptions.
    :returns: list of serialized databases, in input order.
    :raises exception.BadRequest: if building or serializing any entry
        raises ValueError (presumably from the model's validation); the
        error text is passed through with every '%' doubled.
    """
    serialized = []
    try:
        for entry in dbs:
            validated = guest_models.ValidatedMySQLDatabase()
            validated.name = entry.get('name', '')
            validated.character_set = entry.get('character_set', '')
            validated.collate = entry.get('collate', '')
            serialized.append(validated.serialize())
    except ValueError as err:
        # The message may echo user input containing '%', which could act
        # as a format directive (a format string vulnerability). Doubling
        # '%' is safe here because no dict args are supplied.
        raise exception.BadRequest(str(err).replace('%', '%%'))
    return serialized
def populate_users(users):
    """
    Build serialized user definitions from user-provided request data.

    For each entry in ``users`` a guest_models.MySQLUser is populated from
    the entry's 'name' (default ''), 'host' (default None) and 'password'
    (default '') keys. For every item in the entry's optional 'databases'
    collection, that item's 'name' (default '') is assigned to the user's
    ``databases`` attribute.

    :param users: iterable of dict-like user descriptions.
    :returns: list of serialized users, in input order.
    """
    result = []
    for entry in users:
        user = guest_models.MySQLUser()
        user.name = entry.get('name', '')
        user.host = entry.get('host')
        user.password = entry.get('password', '')
        # A missing or empty 'databases' value means no assignments at all.
        # NOTE(review): ``databases`` is assigned once per listed database;
        # this presumably relies on MySQLUser's setter accumulating values
        # rather than overwriting them -- confirm in guest_models.
        for db_entry in entry.get('databases', '') or ():
            user.databases = db_entry.get('name', '')
        result.append(user.serialize())
    return result
def unquote_user_host(user_hostname):
    """Split a URL-quoted ``user@host`` identifier into ``(user, host)``.

    The input is URL-unquoted first. The host is whatever follows the
    *last* '@', so user names that themselves contain '@' stay intact.
    When the unquoted string has no '@' at all, or ends with '@', the whole
    unquoted string is returned as the user and the host is '%'.
    """
    decoded = unquote(user_hostname)
    # No separator, or nothing after the final one: keep the full string
    # as the user name and fall back to the '%' host.
    if '@' not in decoded or decoded.endswith('@'):
        return decoded, '%'
    name, _sep, hostname = decoded.rpartition('@')
    return name, hostname