Use a sensible SQLAlchemy session model

The existing db session strategy was inherited from a bunch of
shell scripts that ran once in a single thread and exited.

The surprising thing is that it even worked at all.  This change
replaces that "strategy" with one where each thread clearly
begins a new session as a context manager and passes that around
to functions that need the DB.  A thread-local session is used
for convenience and extra safety.

This also adds a fake provider that will produce fake images and
servers quickly without needing a real nova or jenkins.  This was
used to develop the database change.

Also some minor logging changes and very brief developer docs.

Change-Id: I45e6564cb061f81d79c47a31e17f5d85cd1d9306
This commit is contained in:
James E. Blair 2013-08-16 20:17:57 -07:00
parent 35d66f0d77
commit a5a78ef441
6 changed files with 354 additions and 129 deletions

13
README Normal file
View File

@ -0,0 +1,13 @@
Developer setup:
mysql -u root
mysql> create database nodepool;
mysql> GRANT ALL ON nodepool.* TO 'nodepool'@'localhost';
mysql> flush privileges;
nodepool -d -c tools/fake.yaml
After each run (the fake nova provider is only in-memory):
mysql> delete from snapshot_image; delete from node;

110
nodepool/fakeprovider.py Normal file
View File

@ -0,0 +1,110 @@
#!/usr/bin/env python
#
# Copyright 2013 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import uuid
import time
import threading
import novaclient
class Dummy(object):
    """Generic attribute bag used to stand in for nova-style objects."""

    def __init__(self, **attrs):
        # Expose every keyword argument as an instance attribute.
        for attr_name, attr_value in attrs.items():
            setattr(self, attr_name, attr_value)

    def delete(self):
        # Defer deletion to whichever fake collection manages this object.
        self.manager.delete(self)
class FakeList(object):
    """In-memory stand-in for a novaclient resource manager.

    Wraps a plain list and mimics the list/find/get/create/delete calls
    nodepool makes against nova collections (servers, images, flavors).
    """

    def __init__(self, l):
        self._list = l

    def list(self):
        # Hand back the live backing list; callers treat it as read-only.
        return self._list

    def find(self, name):
        # First item whose name matches; None when nothing does.
        return next((item for item in self._list if item.name == name), None)

    def get(self, id):
        for item in self._list:
            if item.id == id:
                return item
        # Mirror novaclient's behavior for unknown ids.
        # NOTE(review): relies on novaclient.exceptions being importable via
        # the bare ``import novaclient`` at the top of this file -- confirm
        # the submodule is loaded elsewhere before this path is hit.
        raise novaclient.exceptions.NotFound(404)

    def _finish(self, obj, delay, status):
        # Background helper: flip the object's status after a delay to
        # imitate an asynchronous nova state transition.
        time.sleep(delay)
        obj.status = status

    def delete(self, obj):
        self._list.remove(obj)

    def create(self, **kw):
        server = Dummy(id=uuid.uuid4().hex,
                       name=kw['name'],
                       status='BUILD',
                       addresses=dict(public=[dict(version=4, addr='fake')]),
                       manager=self)
        self._list.append(server)
        # Become ACTIVE half a second later, like a (very fast) real server.
        worker = threading.Thread(target=self._finish,
                                  args=(server, 0.5, 'ACTIVE'))
        worker.start()
        return server

    def create_image(self, server, name):
        # Register a new image with the owning FakeClient; return its id.
        image = self.api.images.create(name=name)
        return image.id
class FakeHTTPClient(object):
    """Minimal HTTP client double for novaclient's extension probe."""

    def get(self, path):
        # nodepool only ever asks for the extension list; any other path
        # falls through and yields None.
        if path != '/extensions':
            return None
        return None, dict(extensions=dict())
class FakeClient(object):
    """Stand-in for a novaclient Client, wired up with fake collections."""

    def __init__(self):
        fake_flavor = Dummy(id='f1', ram=8192)
        fake_image = Dummy(id='i1', name='Fake Precise')
        self.flavors = FakeList([fake_flavor])
        self.images = FakeList([fake_image])
        self.client = FakeHTTPClient()
        self.servers = FakeList([])
        # Hand the server collection a reference back to this client so
        # its create_image() can register snapshots with self.images.
        self.servers.api = self
class FakeSSHClient(object):
    """SSH client double: every command and copy reports success."""

    def ssh(self, description, cmd):
        # Pretend the remote command ran cleanly.
        return True

    def scp(self, src, dest):
        # Pretend the file transfer succeeded.
        return True
class FakeJenkins(object):
    """Jenkins double that tracks registered nodes in a dict."""

    def __init__(self):
        # Maps node name -> the keyword arguments it was created with.
        self._nodes = {}

    def node_exists(self, name):
        return name in self._nodes

    def create_node(self, name, **kw):
        self._nodes[name] = kw

    def delete_node(self, name):
        # Raises KeyError for an unknown name, like a missing-node error.
        self._nodes.pop(name)
# Module-level singleton: get_client() returns this one shared instance
# whenever a provider's auth-url is 'fake', so fake servers and images
# persist in memory across calls for the life of the process.
FAKE_CLIENT = FakeClient()

View File

@ -42,7 +42,7 @@ STATE_NAMES = {
from sqlalchemy import Table, Column, Integer, String, \
MetaData, create_engine
from sqlalchemy.orm import mapper
from sqlalchemy.orm import scoped_session, mapper
from sqlalchemy.orm.session import Session, sessionmaker
metadata = MetaData()
@ -154,10 +154,27 @@ mapper(SnapshotImage, snapshot_image_table,
class NodeDatabase(object):
def __init__(self, dburi):
engine = create_engine(dburi, echo=False)
metadata.create_all(engine)
Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
self.session = Session()
self.engine = create_engine(dburi, echo=False)
metadata.create_all(self.engine)
self.session_factory = sessionmaker(bind=self.engine)
self.session = scoped_session(self.session_factory)
def getSession(self):
return NodeDatabaseSession(self.session)
class NodeDatabaseSession(object):
def __init__(self, session):
self.session = session
def __enter__(self):
return self
def __exit__(self, etype, value, tb):
if etype:
self.session().rollback()
else:
self.session().commit()
def print_state(self):
for provider_name in self.getProviders():
@ -182,40 +199,41 @@ class NodeDatabase(object):
node.state_time, node.ip)
def abort(self):
self.session.rollback()
self.session().rollback()
def commit(self):
self.session.commit()
self.session().commit()
def delete(self, obj):
self.session.delete(obj)
self.session().delete(obj)
def getProviders(self):
return [
x.provider_name for x in
self.session.query(SnapshotImage).distinct(
self.session().query(SnapshotImage).distinct(
snapshot_image_table.c.provider_name).all()]
def getImages(self, provider_name):
return [
x.image_name for x in
self.session.query(SnapshotImage).filter(
self.session().query(SnapshotImage).filter(
snapshot_image_table.c.provider_name == provider_name
).distinct(snapshot_image_table.c.image_name).all()]
def getSnapshotImages(self):
return self.session.query(SnapshotImage).order_by(
return self.session().query(SnapshotImage).order_by(
snapshot_image_table.c.provider_name,
snapshot_image_table.c.image_name).all()
def getSnapshotImage(self, id):
images = self.session.query(SnapshotImage).filter_by(id=id).all()
def getSnapshotImage(self, image_id):
images = self.session().query(SnapshotImage).filter_by(
id=image_id).all()
if not images:
return None
return images[0]
def getCurrentSnapshotImage(self, provider_name, image_name):
images = self.session.query(SnapshotImage).filter(
images = self.session().query(SnapshotImage).filter(
snapshot_image_table.c.provider_name == provider_name,
snapshot_image_table.c.image_name == image_name,
snapshot_image_table.c.state == READY).order_by(
@ -226,13 +244,13 @@ class NodeDatabase(object):
def createSnapshotImage(self, *args, **kwargs):
new = SnapshotImage(*args, **kwargs)
self.session.add(new)
self.session.commit()
self.session().add(new)
self.commit()
return new
def getNodes(self, provider_name=None, image_name=None, target_name=None,
state=None):
exp = self.session.query(Node).order_by(
exp = self.session().query(Node).order_by(
node_table.c.provider_name,
node_table.c.image_name)
if provider_name:
@ -247,24 +265,24 @@ class NodeDatabase(object):
def createNode(self, *args, **kwargs):
new = Node(*args, **kwargs)
self.session.add(new)
self.session.commit()
self.session().add(new)
self.commit()
return new
def getNode(self, id):
nodes = self.session.query(Node).filter_by(id=id).all()
nodes = self.session().query(Node).filter_by(id=id).all()
if not nodes:
return None
return nodes[0]
def getNodeByHostname(self, hostname):
nodes = self.session.query(Node).filter_by(hostname=hostname).all()
nodes = self.session().query(Node).filter_by(hostname=hostname).all()
if not nodes:
return None
return nodes[0]
def getNodeByNodename(self, nodename):
nodes = self.session.query(Node).filter_by(nodename=nodename).all()
nodes = self.session().query(Node).filter_by(nodename=nodename).all()
if not nodes:
return None
return nodes[0]

View File

@ -49,22 +49,22 @@ class NodeCompleteThread(threading.Thread):
threading.Thread.__init__(self)
self.nodename = nodename
self.nodepool = nodepool
self.db = nodedb.NodeDatabase(self.nodepool.config.dburi)
def run(self):
try:
self.handleEvent()
with self.nodepool.db.getSession() as session:
self.handleEvent(session)
except Exception:
self.log.exception("Exception handling event for %s:" %
self.nodename)
def handleEvent(self):
node = self.db.getNodeByNodename(self.nodename)
def handleEvent(self, session):
node = session.getNodeByNodename(self.nodename)
if not node:
self.log.debug("Unable to find node with nodename: %s" %
self.nodename)
return
self.nodepool.deleteNode(node)
self.nodepool.deleteNode(session, node)
class NodeUpdateListener(threading.Thread):
@ -78,7 +78,6 @@ class NodeUpdateListener(threading.Thread):
self.socket.setsockopt(zmq.SUBSCRIBE, event_filter)
self.socket.connect(addr)
self._stopped = False
self.db = nodedb.NodeDatabase(self.nodepool.config.dburi)
def run(self):
while not self._stopped:
@ -107,14 +106,15 @@ class NodeUpdateListener(threading.Thread):
topic)
def handleStartPhase(self, nodename):
node = self.db.getNodeByNodename(nodename)
with self.nodepool.db.getSession() as session:
node = session.getNodeByNodename(nodename)
if not node:
self.log.debug("Unable to find node with nodename: %s" %
nodename)
return
self.log.info("Setting node id: %s to USED" % node.id)
node.state = nodedb.USED
self.nodepool.updateStats(node.provider_name)
self.nodepool.updateStats(session, node.provider_name)
def handleCompletePhase(self, nodename):
t = NodeCompleteThread(self.nodepool, nodename)
@ -133,18 +133,24 @@ class NodeLauncher(threading.Thread):
self.nodepool = nodepool
def run(self):
try:
self._run()
except Exception:
self.log.exception("Exception in run method:")
def _run(self):
with self.nodepool.db.getSession() as session:
self.log.debug("Launching node id: %s" % self.node_id)
try:
self.db = nodedb.NodeDatabase(self.nodepool.config.dburi)
self.node = self.db.getNode(self.node_id)
self.node = session.getNode(self.node_id)
self.client = utils.get_client(self.provider)
except Exception:
self.log.exception("Exception preparing to launch node id: %s:" %
self.node_id)
self.log.exception("Exception preparing to launch node id: %s:"
% self.node_id)
return
try:
self.launchNode()
self.launchNode(session)
except Exception:
self.log.exception("Exception launching node id: %s:" %
self.node_id)
@ -155,7 +161,7 @@ class NodeLauncher(threading.Thread):
self.node_id)
return
def launchNode(self):
def launchNode(self, session):
start_time = time.time()
hostname = '%s-%s-%s.slave.openstack.org' % (
@ -165,7 +171,7 @@ class NodeLauncher(threading.Thread):
self.node.target_name = self.target.name
flavor = utils.get_flavor(self.client, self.image.min_ram)
snap_image = self.db.getCurrentSnapshotImage(
snap_image = session.getCurrentSnapshotImage(
self.provider.name, self.image.name)
if not snap_image:
raise Exception("Unable to find current snapshot image %s in %s" %
@ -179,7 +185,7 @@ class NodeLauncher(threading.Thread):
server, key = utils.create_server(self.client, hostname,
remote_snap_image, flavor)
self.node.external_id = server.id
self.db.commit()
session.commit()
self.log.debug("Waiting for server %s for node id: %s" %
(server.id, self.node.id))
@ -213,7 +219,7 @@ class NodeLauncher(threading.Thread):
# Jenkins might immediately use the node before we've updated
# the state:
self.node.state = nodedb.READY
self.nodepool.updateStats(self.provider.name)
self.nodepool.updateStats(session, self.provider.name)
self.log.info("Node id: %s is ready" % self.node.id)
if self.target.jenkins_url:
@ -222,7 +228,7 @@ class NodeLauncher(threading.Thread):
self.log.info("Node id: %s added to jenkins" % self.node.id)
def createJenkinsNode(self):
jenkins = myjenkins.Jenkins(self.target.jenkins_url,
jenkins = utils.get_jenkins(self.target.jenkins_url,
self.target.jenkins_user,
self.target.jenkins_apikey)
node_desc = 'Dynamic single use %s node' % self.image.name
@ -267,19 +273,27 @@ class ImageUpdater(threading.Thread):
self.scriptdir = self.nodepool.config.scriptdir
def run(self):
try:
self._run()
except Exception:
self.log.exception("Exception in run method:")
def _run(self):
with self.nodepool.db.getSession() as session:
self.log.debug("Updating image %s in %s " % (self.image.name,
self.provider.name))
try:
self.db = nodedb.NodeDatabase(self.nodepool.config.dburi)
self.snap_image = self.db.getSnapshotImage(self.snap_image_id)
self.snap_image = session.getSnapshotImage(
self.snap_image_id)
self.client = utils.get_client(self.provider)
except Exception:
self.log.exception("Exception preparing to update image %s in %s:"
% (self.image.name, self.provider.name))
self.log.exception("Exception preparing to update image %s "
"in %s:" % (self.image.name,
self.provider.name))
return
try:
self.updateImage()
self.updateImage(session)
except Exception:
self.log.exception("Exception updating image %s in %s:" %
(self.image.name, self.provider.name))
@ -291,7 +305,7 @@ class ImageUpdater(threading.Thread):
self.snap_image.id)
return
def updateImage(self):
def updateImage(self, session):
start_time = time.time()
timestamp = int(start_time)
@ -308,7 +322,7 @@ class ImageUpdater(threading.Thread):
self.snap_image.hostname = hostname
self.snap_image.version = timestamp
self.snap_image.server_external_id = server.id
self.db.commit()
session.commit()
self.log.debug("Image id: %s waiting for server %s" %
(self.snap_image.id, server.id))
@ -322,7 +336,7 @@ class ImageUpdater(threading.Thread):
image = utils.create_image(self.client, server, hostname)
self.snap_image.external_id = image.id
self.db.commit()
session.commit()
self.log.debug("Image id: %s building image %s" %
(self.snap_image.id, image.id))
# It can take a _very_ long time for Rackspace 1.0 to save an image
@ -339,6 +353,7 @@ class ImageUpdater(threading.Thread):
statsd.incr(key)
self.snap_image.state = nodedb.READY
session.commit()
self.log.info("Image %s in %s is ready" % (hostname,
self.provider.name))
@ -426,6 +441,7 @@ class NodePool(threading.Thread):
self.zmq_context = None
self.zmq_listeners = {}
self.db = None
self.dburi = None
self.apsched = apscheduler.scheduler.Scheduler()
self.apsched.start()
@ -452,7 +468,7 @@ class NodePool(threading.Thread):
self.apsched.unschedule_job(self.update_job)
parts = update_cron.split()
minute, hour, dom, month, dow = parts[:5]
self.apsched.add_cron_job(self.updateImages,
self.apsched.add_cron_job(self._doUpdateImages,
day=dom,
day_of_week=dow,
hour=hour,
@ -463,7 +479,7 @@ class NodePool(threading.Thread):
self.apsched.unschedule_job(self.cleanup_job)
parts = cleanup_cron.split()
minute, hour, dom, month, dow = parts[:5]
self.apsched.add_cron_job(self.periodicCleanup,
self.apsched.add_cron_job(self._doPeriodicCleanup,
day=dom,
day_of_week=dow,
hour=hour,
@ -524,6 +540,8 @@ class NodePool(threading.Thread):
i.providers[p.name] = p
p.min_ready = provider['min-ready']
self.config = newconfig
if self.config.dburi != self.dburi:
self.dburi = self.config.dburi
self.db = nodedb.NodeDatabase(self.config.dburi)
self.startUpdateListeners(config['zmq-publishers'])
@ -545,15 +563,15 @@ class NodePool(threading.Thread):
self.zmq_listeners[addr] = listener
listener.start()
def getNumNeededNodes(self, target, provider, image):
def getNumNeededNodes(self, session, target, provider, image):
# Count machines that are ready and machines that are building,
# so that if the provider is very slow, we aren't queueing up tons
# of machines to be built.
n_ready = len(self.db.getNodes(provider.name, image.name, target.name,
n_ready = len(session.getNodes(provider.name, image.name, target.name,
nodedb.READY))
n_building = len(self.db.getNodes(provider.name, image.name,
n_building = len(session.getNodes(provider.name, image.name,
target.name, nodedb.BUILDING))
n_provider = len(self.db.getNodes(provider.name))
n_provider = len(session.getNodes(provider.name))
num_to_launch = provider.min_ready - (n_ready + n_building)
# Don't launch more than our provider max
@ -567,30 +585,37 @@ class NodePool(threading.Thread):
def run(self):
while not self._stopped:
try:
self.loadConfig()
self.checkForMissingImages()
with self.db.getSession() as session:
self._run(session)
except Exception:
self.log.exception("Exception in main loop:")
time.sleep(WATERMARK_SLEEP)
def _run(self, session):
self.checkForMissingImages(session)
for target in self.config.targets.values():
self.log.debug("Examining target: %s" % target.name)
for image in target.images.values():
for provider in image.providers.values():
num_to_launch = self.getNumNeededNodes(
target, provider, image)
session, target, provider, image)
if num_to_launch:
self.log.info("Need to launch %s %s nodes for "
"%s on %s" %
(num_to_launch, image.name,
target.name, provider.name))
for i in range(num_to_launch):
snap_image = self.db.getCurrentSnapshotImage(
snap_image = session.getCurrentSnapshotImage(
provider.name, image.name)
if not snap_image:
self.log.debug("No current image for %s on %s"
% (provider.name, image.name))
else:
self.launchNode(provider, image, target)
time.sleep(WATERMARK_SLEEP)
self.launchNode(session, provider, image, target)
def checkForMissingImages(self):
def checkForMissingImages(self, session):
# If we are missing an image, run the image update function
# outside of its schedule.
missing = False
@ -598,7 +623,7 @@ class NodePool(threading.Thread):
for image in target.images.values():
for provider in image.providers.values():
found = False
for snap_image in self.db.getSnapshotImages():
for snap_image in session.getSnapshotImages():
if (snap_image.provider_name == provider.name and
snap_image.image_name == image.name and
snap_image.state in [nodedb.READY,
@ -609,14 +634,21 @@ class NodePool(threading.Thread):
(image.name, provider.name))
missing = True
if missing:
self.updateImages()
self.updateImages(session)
def updateImages(self):
def _doUpdateImages(self):
try:
with self.db.getSession() as session:
self.updateImages(session)
except Exception:
self.log.exception("Exception in periodic image update:")
def updateImages(self, session):
# This function should be run periodically to create new snapshot
# images.
for provider in self.config.providers.values():
for image in provider.images.values():
snap_image = self.db.createSnapshotImage(
snap_image = session.createSnapshotImage(
provider_name=provider.name,
image_name=image.name)
t = ImageUpdater(self, provider, image, snap_image.id)
@ -625,33 +657,33 @@ class NodePool(threading.Thread):
# Just to keep things clearer.
time.sleep(2)
def launchNode(self, provider, image, target):
def launchNode(self, session, provider, image, target):
provider = self.config.providers[provider.name]
image = provider.images[image.name]
node = self.db.createNode(provider.name, image.name, target.name)
node = session.createNode(provider.name, image.name, target.name)
t = NodeLauncher(self, provider, image, target, node.id)
t.start()
def deleteNode(self, node):
def deleteNode(self, session, node):
# Delete a node
start_time = time.time()
node.state = nodedb.DELETE
self.updateStats(node.provider_name)
self.updateStats(session, node.provider_name)
provider = self.config.providers[node.provider_name]
target = self.config.targets[node.target_name]
client = utils.get_client(provider)
if target.jenkins_url:
jenkins = myjenkins.Jenkins(target.jenkins_url,
jenkins = utils.get_jenkins(target.jenkins_url,
target.jenkins_user,
target.jenkins_apikey)
jenkins_name = node.nodename
if jenkins.node_exists(jenkins_name):
jenkins.delete_node(jenkins_name)
self.log.info("Deleted jenkins node ID: %s" % node.id)
self.log.info("Deleted jenkins node id: %s" % node.id)
utils.delete_node(client, node)
self.log.info("Deleted node ID: %s" % node.id)
self.log.info("Deleted node id: %s" % node.id)
if statsd:
dt = int((time.time() - start_time) * 1000)
@ -660,7 +692,7 @@ class NodePool(threading.Thread):
node.target_name)
statsd.timing(key, dt)
statsd.incr(key)
self.updateStats(node.provider_name)
self.updateStats(session, node.provider_name)
def deleteImage(self, snap_image):
# Delete a node
@ -669,16 +701,22 @@ class NodePool(threading.Thread):
client = utils.get_client(provider)
utils.delete_image(client, snap_image)
self.log.info("Deleted image ID: %s" % snap_image.id)
self.log.info("Deleted image id: %s" % snap_image.id)
def periodicCleanup(self):
def _doPeriodicCleanup(self):
try:
with self.db.getSession() as session:
self.periodicCleanup(session)
except Exception:
self.log.exception("Exception in periodic cleanup:")
def periodicCleanup(self, session):
# This function should be run periodically to clean up any hosts
# that may have slipped through the cracks, as well as to remove
# old images.
self.log.debug("Starting periodic cleanup")
db = nodedb.NodeDatabase(self.config.dburi)
for node in db.getNodes():
for node in session.getNodes():
if node.state in [nodedb.READY, nodedb.HOLD]:
continue
delete = False
@ -694,12 +732,12 @@ class NodePool(threading.Thread):
delete = True
if delete:
try:
self.deleteNode(node)
self.deleteNode(session, node)
except Exception:
self.log.exception("Exception deleting node ID: "
self.log.exception("Exception deleting node id: "
"%s" % node.id)
for image in db.getSnapshotImages():
for image in session.getSnapshotImages():
# Normally, reap images that have sat in their current state
# for 24 hours, unless the image is the current snapshot
delete = False
@ -713,7 +751,7 @@ class NodePool(threading.Thread):
self.log.info("Deleting image id: %s which has no current "
"base image" % image.id)
else:
current = db.getCurrentSnapshotImage(image.provider_name,
current = session.getCurrentSnapshotImage(image.provider_name,
image.image_name)
if (current and image != current and
(time.time() - current.state_time) > KEEP_OLD_IMAGE):
@ -729,11 +767,10 @@ class NodePool(threading.Thread):
image.id)
self.log.debug("Finished periodic cleanup")
def updateStats(self, provider_name):
def updateStats(self, session, provider_name):
if not statsd:
return
# This may be called outside of the main thread.
db = nodedb.NodeDatabase(self.config.dburi)
provider = self.config.providers[provider_name]
states = {}
@ -750,7 +787,7 @@ class NodePool(threading.Thread):
key = '%s.%s' % (base_key, state)
states[key] = 0
for node in db.getNodes():
for node in session.getNodes():
if node.state not in nodedb.STATE_NAMES:
continue
key = 'nodepool.target.%s.%s.%s.%s' % (

View File

@ -21,9 +21,11 @@ import time
import paramiko
import socket
import logging
import myjenkins
from sshclient import SSHClient
import nodedb
import fakeprovider
log = logging.getLogger("nodepool.utils")
@ -48,9 +50,18 @@ def get_client(provider):
kwargs['service_name'] = provider.service_name
if provider.region_name:
kwargs['region_name'] = provider.region_name
if provider.auth_url == 'fake':
return fakeprovider.FAKE_CLIENT
client = novaclient.client.Client(*args, **kwargs)
return client
def get_jenkins(url, user, apikey):
if apikey == 'fake':
return fakeprovider.FakeJenkins()
return myjenkins.Jenkins(url, user, apikey)
extension_cache = {}
@ -150,6 +161,8 @@ def wait_for_resource(wait_resource, timeout=3600):
def ssh_connect(ip, username, connect_kwargs={}, timeout=60):
if ip == 'fake':
return fakeprovider.FakeSSHClient()
# HPcloud may return errno 111 for about 30 seconds after adding the IP
for count in iterate_timeout(timeout, "ssh access"):
try:

34
tools/fake.yaml Normal file
View File

@ -0,0 +1,34 @@
script-dir: .
dburi: 'mysql://nodepool@localhost/nodepool'
cron:
cleanup: '*/1 * * * *'
update-image: '14 2 * * *'
zmq-publishers:
- tcp://localhost:8888
providers:
- name: fake-provider
username: 'fake'
password: 'fake'
auth-url: 'fake'
project-id: 'fake'
max-servers: 96
images:
- name: nodepool-fake
base-image: 'Fake Precise'
min-ram: 8192
setup: prepare_node_devstack.sh
targets:
- name: fake-jenkins
jenkins:
url: https://jenkins.example.org/
user: fake
apikey: fake
images:
- name: nodepool-fake
providers:
- name: fake-provider
min-ready: 6