Replace usage of LegacyEngineFacade

Switch to using oslo_db.sqlalchemy.enginefacade instead, as this is
required for SQLAlchemy 2.x support.

Change-Id: Ifcad28239b6907b8ca396d348cbfa54185355f68
This commit is contained in:
Matt Crees
2024-02-26 16:58:58 +00:00
parent 3e35eb05ed
commit f2c103535e
14 changed files with 589 additions and 644 deletions

View File

@@ -32,10 +32,11 @@ def run_migrations_online(target_metadata, version_table):
:param target_metadata: Model's metadata used for autogenerate support. :param target_metadata: Model's metadata used for autogenerate support.
:param version_table: Override the default version table for alembic. :param version_table: Override the default version table for alembic.
""" """
-    engine = db.get_engine()
-    with engine.connect() as connection:
-        context.configure(connection=connection,
-                          target_metadata=target_metadata,
-                          version_table=version_table)
-        with context.begin_transaction():
-            context.run_migrations()
+    with db.session_for_write() as session:
+        engine = session.get_bind()
+        with engine.connect() as connection:
+            context.configure(connection=connection,
+                              target_metadata=target_metadata,
+                              version_table=version_table)
+            with context.begin_transaction():
+                context.run_migrations()

View File

@@ -13,28 +13,26 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
# #
-from oslo_config import cfg
-from oslo_db.sqlalchemy import session
+import threading
+
+from oslo_db.sqlalchemy import enginefacade

+_CONTEXT = threading.local()
 _FACADE = None

 def _create_facade_lazily():
     global _FACADE
     if _FACADE is None:
-        # FIXME(priteau): Remove autocommit=True (and ideally use of
-        # LegacyEngineFacade) asap since it's not compatible with SQLAlchemy
-        # 2.0.
-        _FACADE = session.EngineFacade.from_config(cfg.CONF, sqlite_fk=True,
-                                                   autocommit=True)
+        ctx = enginefacade.transaction_context()
+        ctx.configure(sqlite_fk=True)
+        _FACADE = ctx
     return _FACADE

-def get_engine():
-    facade = _create_facade_lazily()
-    return facade.get_engine()
+def session_for_read():
+    return _create_facade_lazily().reader.using(_CONTEXT)

-def get_session(**kwargs):
-    facade = _create_facade_lazily()
-    return facade.get_session(**kwargs)
+def session_for_write():
+    return _create_facade_lazily().writer.using(_CONTEXT)

View File

@@ -30,16 +30,15 @@ def get_backend():
class State(api.State): class State(api.State):
def get_state(self, name): def get_state(self, name):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
models.StateInfo, models.StateInfo,
session) session)
q = q.filter(models.StateInfo.name == name) q = q.filter(models.StateInfo.name == name)
return q.value(models.StateInfo.state) return q.value(models.StateInfo.state)
def set_state(self, name, state): def set_state(self, name, state):
session = db.get_session() with db.session_for_write() as session:
with session.begin():
try: try:
q = utils.model_query( q = utils.model_query(
models.StateInfo, models.StateInfo,
@@ -55,16 +54,15 @@ class State(api.State):
return db_state.state return db_state.state
def get_metadata(self, name): def get_metadata(self, name):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
models.StateInfo, models.StateInfo,
session) session)
q.filter(models.StateInfo.name == name) q.filter(models.StateInfo.name == name)
return q.value(models.StateInfo.s_metadata) return q.value(models.StateInfo.s_metadata)
def set_metadata(self, name, metadata): def set_metadata(self, name, metadata):
session = db.get_session() with db.session_for_write() as session:
with session.begin():
try: try:
q = utils.model_query( q = utils.model_query(
models.StateInfo, models.StateInfo,
@@ -83,20 +81,19 @@ class ModuleInfo(api.ModuleInfo):
"""Base class for module info management.""" """Base class for module info management."""
def get_priority(self, name): def get_priority(self, name):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
models.ModuleStateInfo, models.ModuleStateInfo,
session) session)
q = q.filter(models.ModuleStateInfo.name == name) q = q.filter(models.ModuleStateInfo.name == name)
res = q.value(models.ModuleStateInfo.priority) res = q.value(models.ModuleStateInfo.priority)
if res: if res:
return int(res) return int(res)
else: else:
return 1 return 1
def set_priority(self, name, priority): def set_priority(self, name, priority):
session = db.get_session() with db.session_for_write() as session:
with session.begin():
try: try:
q = utils.model_query( q = utils.model_query(
models.ModuleStateInfo, models.ModuleStateInfo,
@@ -113,20 +110,19 @@ class ModuleInfo(api.ModuleInfo):
return int(db_state.priority) return int(db_state.priority)
def get_state(self, name): def get_state(self, name):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = utils.model_query( q = utils.model_query(
models.ModuleStateInfo, models.ModuleStateInfo,
session) session)
q = q.filter(models.ModuleStateInfo.name == name) q = q.filter(models.ModuleStateInfo.name == name)
res = q.value(models.ModuleStateInfo.state) res = q.value(models.ModuleStateInfo.state)
return bool(res) return bool(res)
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
return None return None
def set_state(self, name, state): def set_state(self, name, state):
session = db.get_session() with db.session_for_write() as session:
with session.begin():
try: try:
q = utils.model_query( q = utils.model_query(
models.ModuleStateInfo, models.ModuleStateInfo,
@@ -145,20 +141,19 @@ class ServiceToCollectorMapping(object):
"""Base class for service to collector mapping.""" """Base class for service to collector mapping."""
def get_mapping(self, service): def get_mapping(self, service):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = utils.model_query( q = utils.model_query(
models.ServiceToCollectorMapping, models.ServiceToCollectorMapping,
session) session)
q = q.filter( q = q.filter(
models.ServiceToCollectorMapping.service == service) models.ServiceToCollectorMapping.service == service)
return q.one() return q.one()
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchMapping(service) raise api.NoSuchMapping(service)
def set_mapping(self, service, collector): def set_mapping(self, service, collector):
session = db.get_session() with db.session_for_write() as session:
with session.begin():
try: try:
q = utils.model_query( q = utils.model_query(
models.ServiceToCollectorMapping, models.ServiceToCollectorMapping,
@@ -176,37 +171,37 @@ class ServiceToCollectorMapping(object):
return db_mapping return db_mapping
def list_services(self, collector=None): def list_services(self, collector=None):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
models.ServiceToCollectorMapping, models.ServiceToCollectorMapping,
session) session)
if collector: if collector:
q = q.filter( q = q.filter(
models.ServiceToCollectorMapping.collector == collector) models.ServiceToCollectorMapping.collector == collector)
res = q.distinct().values( res = q.distinct().values(
models.ServiceToCollectorMapping.service) models.ServiceToCollectorMapping.service)
return res return res
def list_mappings(self, collector=None): def list_mappings(self, collector=None):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
models.ServiceToCollectorMapping, models.ServiceToCollectorMapping,
session) session)
if collector: if collector:
q = q.filter( q = q.filter(
models.ServiceToCollectorMapping.collector == collector) models.ServiceToCollectorMapping.collector == collector)
res = q.all() res = q.all()
return res return res
def delete_mapping(self, service): def delete_mapping(self, service):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.ServiceToCollectorMapping, models.ServiceToCollectorMapping,
session) session)
q = q.filter(models.ServiceToCollectorMapping.service == service) q = q.filter(models.ServiceToCollectorMapping.service == service)
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchMapping(service) raise api.NoSuchMapping(service)
class DBAPIManager(object): class DBAPIManager(object):

View File

@@ -29,13 +29,11 @@ class HashMapBase(models.ModelBase):
'mysql_engine': "InnoDB"} 'mysql_engine': "InnoDB"}
fk_to_resolve = {} fk_to_resolve = {}
def save(self, session=None): def save(self):
from cloudkitty import db from cloudkitty import db
if session is None: with db.session_for_write() as session:
session = db.get_session() super(HashMapBase, self).save(session=session)
super(HashMapBase, self).save(session=session)
def as_dict(self): def as_dict(self):
d = {} d = {}

View File

@@ -34,131 +34,131 @@ class HashMap(api.HashMap):
return migration return migration
def get_service(self, name=None, uuid=None): def get_service(self, name=None, uuid=None):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapService) q = session.query(models.HashMapService)
if name: if name:
q = q.filter( q = q.filter(
models.HashMapService.name == name) models.HashMapService.name == name)
elif uuid: elif uuid:
q = q.filter( q = q.filter(
models.HashMapService.service_id == uuid) models.HashMapService.service_id == uuid)
else: else:
raise api.ClientHashMapError( raise api.ClientHashMapError(
'You must specify either name or uuid.') 'You must specify either name or uuid.')
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchService(name=name, uuid=uuid) raise api.NoSuchService(name=name, uuid=uuid)
def get_field(self, uuid=None, service_uuid=None, name=None): def get_field(self, uuid=None, service_uuid=None, name=None):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapField) q = session.query(models.HashMapField)
if uuid: if uuid:
q = q.filter( q = q.filter(
models.HashMapField.field_id == uuid) models.HashMapField.field_id == uuid)
elif service_uuid and name: elif service_uuid and name:
q = q.join( q = q.join(
models.HashMapField.service) models.HashMapField.service)
q = q.filter( q = q.filter(
models.HashMapService.service_id == service_uuid, models.HashMapService.service_id == service_uuid,
models.HashMapField.name == name) models.HashMapField.name == name)
else: else:
raise api.ClientHashMapError( raise api.ClientHashMapError(
'You must specify either a uuid' 'You must specify either a uuid'
' or a service_uuid and a name.') ' or a service_uuid and a name.')
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchField(uuid) raise api.NoSuchField(uuid)
def get_group(self, uuid=None, name=None): def get_group(self, uuid=None, name=None):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapGroup) q = session.query(models.HashMapGroup)
if uuid: if uuid:
q = q.filter( q = q.filter(
models.HashMapGroup.group_id == uuid) models.HashMapGroup.group_id == uuid)
if name: if name:
q = q.filter( q = q.filter(
models.HashMapGroup.name == name) models.HashMapGroup.name == name)
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchGroup(name, uuid) raise api.NoSuchGroup(name, uuid)
def get_mapping(self, uuid): def get_mapping(self, uuid):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapMapping) q = session.query(models.HashMapMapping)
q = q.filter( q = q.filter(
models.HashMapMapping.mapping_id == uuid) models.HashMapMapping.mapping_id == uuid)
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchMapping(uuid) raise api.NoSuchMapping(uuid)
def get_threshold(self, uuid): def get_threshold(self, uuid):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapThreshold) q = session.query(models.HashMapThreshold)
q = q.filter( q = q.filter(
models.HashMapThreshold.threshold_id == uuid) models.HashMapThreshold.threshold_id == uuid)
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchThreshold(uuid) raise api.NoSuchThreshold(uuid)
def get_group_from_mapping(self, uuid): def get_group_from_mapping(self, uuid):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapGroup) q = session.query(models.HashMapGroup)
q = q.join( q = q.join(
models.HashMapGroup.mappings) models.HashMapGroup.mappings)
q = q.filter( q = q.filter(
models.HashMapMapping.mapping_id == uuid) models.HashMapMapping.mapping_id == uuid)
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.MappingHasNoGroup(uuid=uuid) raise api.MappingHasNoGroup(uuid=uuid)
def get_group_from_threshold(self, uuid): def get_group_from_threshold(self, uuid):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.HashMapGroup) q = session.query(models.HashMapGroup)
q = q.join( q = q.join(
models.HashMapGroup.thresholds) models.HashMapGroup.thresholds)
q = q.filter( q = q.filter(
models.HashMapThreshold.threshold_id == uuid) models.HashMapThreshold.threshold_id == uuid)
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.ThresholdHasNoGroup(uuid=uuid) raise api.ThresholdHasNoGroup(uuid=uuid)
def list_services(self): def list_services(self):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.HashMapService) q = session.query(models.HashMapService)
res = q.values( res = q.values(
models.HashMapService.service_id) models.HashMapService.service_id)
return [uuid[0] for uuid in res] return [uuid[0] for uuid in res]
def list_fields(self, service_uuid): def list_fields(self, service_uuid):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.HashMapField) q = session.query(models.HashMapField)
q = q.join( q = q.join(
models.HashMapField.service) models.HashMapField.service)
q = q.filter( q = q.filter(
models.HashMapService.service_id == service_uuid) models.HashMapService.service_id == service_uuid)
res = q.values(models.HashMapField.field_id) res = q.values(models.HashMapField.field_id)
return [uuid[0] for uuid in res] return [uuid[0] for uuid in res]
def list_groups(self): def list_groups(self):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.HashMapGroup) q = session.query(models.HashMapGroup)
res = q.values( res = q.values(
models.HashMapGroup.group_id) models.HashMapGroup.group_id)
return [uuid[0] for uuid in res] return [uuid[0] for uuid in res]
def list_mappings(self, def list_mappings(self,
service_uuid=None, service_uuid=None,
@@ -167,33 +167,34 @@ class HashMap(api.HashMap):
no_group=False, no_group=False,
**kwargs): **kwargs):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.HashMapMapping) q = session.query(models.HashMapMapping)
if service_uuid: if service_uuid:
q = q.join( q = q.join(
models.HashMapMapping.service) models.HashMapMapping.service)
q = q.filter( q = q.filter(
models.HashMapService.service_id == service_uuid) models.HashMapService.service_id == service_uuid)
elif field_uuid: elif field_uuid:
q = q.join( q = q.join(
models.HashMapMapping.field) models.HashMapMapping.field)
q = q.filter(models.HashMapField.field_id == field_uuid) q = q.filter(models.HashMapField.field_id == field_uuid)
elif not service_uuid and not field_uuid and not group_uuid: elif not service_uuid and not field_uuid and not group_uuid:
raise api.ClientHashMapError( raise api.ClientHashMapError(
'You must specify either service_uuid,' 'You must specify either service_uuid,'
' field_uuid or group_uuid.') ' field_uuid or group_uuid.')
if 'tenant_uuid' in kwargs: if 'tenant_uuid' in kwargs:
q = q.filter( q = q.filter(
models.HashMapMapping.tenant_id == kwargs.get('tenant_uuid')) models.HashMapMapping.tenant_id == kwargs.get(
if group_uuid: 'tenant_uuid'))
q = q.join( if group_uuid:
models.HashMapMapping.group) q = q.join(
q = q.filter(models.HashMapGroup.group_id == group_uuid) models.HashMapMapping.group)
elif no_group: q = q.filter(models.HashMapGroup.group_id == group_uuid)
q = q.filter(models.HashMapMapping.group_id == None) # noqa elif no_group:
res = q.values( q = q.filter(models.HashMapMapping.group_id == None) # noqa
models.HashMapMapping.mapping_id) res = q.values(
return [uuid[0] for uuid in res] models.HashMapMapping.mapping_id)
return [uuid[0] for uuid in res]
def list_thresholds(self, def list_thresholds(self,
service_uuid=None, service_uuid=None,
@@ -202,38 +203,38 @@ class HashMap(api.HashMap):
no_group=False, no_group=False,
**kwargs): **kwargs):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.HashMapThreshold) q = session.query(models.HashMapThreshold)
if service_uuid: if service_uuid:
q = q.join( q = q.join(
models.HashMapThreshold.service) models.HashMapThreshold.service)
q = q.filter( q = q.filter(
models.HashMapService.service_id == service_uuid) models.HashMapService.service_id == service_uuid)
elif field_uuid: elif field_uuid:
q = q.join( q = q.join(
models.HashMapThreshold.field) models.HashMapThreshold.field)
q = q.filter(models.HashMapField.field_id == field_uuid) q = q.filter(models.HashMapField.field_id == field_uuid)
elif not service_uuid and not field_uuid and not group_uuid: elif not service_uuid and not field_uuid and not group_uuid:
raise api.ClientHashMapError( raise api.ClientHashMapError(
'You must specify either service_uuid,' 'You must specify either service_uuid,'
' field_uuid or group_uuid.') ' field_uuid or group_uuid.')
if 'tenant_uuid' in kwargs: if 'tenant_uuid' in kwargs:
q = q.filter( q = q.filter(
models.HashMapThreshold.tenant_id == kwargs.get('tenant_uuid')) models.HashMapThreshold.tenant_id == kwargs.get(
if group_uuid: 'tenant_uuid'))
q = q.join( if group_uuid:
models.HashMapThreshold.group) q = q.join(
q = q.filter(models.HashMapGroup.group_id == group_uuid) models.HashMapThreshold.group)
elif no_group: q = q.filter(models.HashMapGroup.group_id == group_uuid)
q = q.filter(models.HashMapThreshold.group_id == None) # noqa elif no_group:
res = q.values( q = q.filter(models.HashMapThreshold.group_id == None) # noqa
models.HashMapThreshold.threshold_id) res = q.values(
return [uuid[0] for uuid in res] models.HashMapThreshold.threshold_id)
return [uuid[0] for uuid in res]
def create_service(self, name): def create_service(self, name):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
service_db = models.HashMapService(name=name) service_db = models.HashMapService(name=name)
service_db.service_id = uuidutils.generate_uuid() service_db.service_id = uuidutils.generate_uuid()
session.add(service_db) session.add(service_db)
@@ -246,9 +247,8 @@ class HashMap(api.HashMap):
def create_field(self, service_uuid, name): def create_field(self, service_uuid, name):
service_db = self.get_service(uuid=service_uuid) service_db = self.get_service(uuid=service_uuid)
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
field_db = models.HashMapField( field_db = models.HashMapField(
service_id=service_db.id, service_id=service_db.id,
name=name, name=name,
@@ -264,9 +264,8 @@ class HashMap(api.HashMap):
return field_db return field_db
def create_group(self, name): def create_group(self, name):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
group_db = models.HashMapGroup( group_db = models.HashMapGroup(
name=name, name=name,
group_id=uuidutils.generate_uuid()) group_id=uuidutils.generate_uuid())
@@ -308,9 +307,8 @@ class HashMap(api.HashMap):
if group_id: if group_id:
group_db = self.get_group(uuid=group_id) group_db = self.get_group(uuid=group_id)
group_fk = group_db.id group_fk = group_db.id
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
field_map = models.HashMapMapping( field_map = models.HashMapMapping(
mapping_id=uuidutils.generate_uuid(), mapping_id=uuidutils.generate_uuid(),
value=value, value=value,
@@ -365,9 +363,8 @@ class HashMap(api.HashMap):
if group_id: if group_id:
group_db = self.get_group(uuid=group_id) group_db = self.get_group(uuid=group_id)
group_fk = group_db.id group_fk = group_db.id
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
threshold_db = models.HashMapThreshold( threshold_db = models.HashMapThreshold(
threshold_id=uuidutils.generate_uuid(), threshold_id=uuidutils.generate_uuid(),
level=level, level=level,
@@ -395,9 +392,8 @@ class HashMap(api.HashMap):
return threshold_db return threshold_db
def update_mapping(self, uuid, **kwargs): def update_mapping(self, uuid, **kwargs):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
q = session.query(models.HashMapMapping) q = session.query(models.HashMapMapping)
q = q.filter( q = q.filter(
models.HashMapMapping.mapping_id == uuid) models.HashMapMapping.mapping_id == uuid)
@@ -442,9 +438,8 @@ class HashMap(api.HashMap):
raise api.NoSuchMapping(uuid) raise api.NoSuchMapping(uuid)
def update_threshold(self, uuid, **kwargs): def update_threshold(self, uuid, **kwargs):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
q = session.query(models.HashMapThreshold) q = session.query(models.HashMapThreshold)
q = q.filter( q = q.filter(
models.HashMapThreshold.threshold_id == uuid) models.HashMapThreshold.threshold_id == uuid)
@@ -483,38 +478,37 @@ class HashMap(api.HashMap):
raise api.NoSuchThreshold(uuid) raise api.NoSuchThreshold(uuid)
def delete_service(self, name=None, uuid=None): def delete_service(self, name=None, uuid=None):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.HashMapService, models.HashMapService,
session) session)
if name: if name:
q = q.filter(models.HashMapService.name == name) q = q.filter(models.HashMapService.name == name)
elif uuid: elif uuid:
q = q.filter(models.HashMapService.service_id == uuid) q = q.filter(models.HashMapService.service_id == uuid)
else: else:
raise api.ClientHashMapError( raise api.ClientHashMapError(
'You must specify either name or uuid.') 'You must specify either name or uuid.')
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchService(name, uuid) raise api.NoSuchService(name, uuid)
def delete_field(self, uuid): def delete_field(self, uuid):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.HashMapField, models.HashMapField,
session) session)
q = q.filter(models.HashMapField.field_id == uuid) q = q.filter(models.HashMapField.field_id == uuid)
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchField(uuid) raise api.NoSuchField(uuid)
def delete_group(self, uuid, recurse=True): def delete_group(self, uuid, recurse=True):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.HashMapGroup, models.HashMapGroup,
session) session)
q = q.filter(models.HashMapGroup.group_id == uuid) q = q.filter(models.HashMapGroup.group_id == uuid)
with session.begin():
try: try:
r = q.with_for_update().one() r = q.with_for_update().one()
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
@@ -527,21 +521,21 @@ class HashMap(api.HashMap):
q.delete() q.delete()
def delete_mapping(self, uuid): def delete_mapping(self, uuid):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.HashMapMapping, models.HashMapMapping,
session) session)
q = q.filter(models.HashMapMapping.mapping_id == uuid) q = q.filter(models.HashMapMapping.mapping_id == uuid)
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchMapping(uuid) raise api.NoSuchMapping(uuid)
def delete_threshold(self, uuid): def delete_threshold(self, uuid):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.HashMapThreshold, models.HashMapThreshold,
session) session)
q = q.filter(models.HashMapThreshold.threshold_id == uuid) q = q.filter(models.HashMapThreshold.threshold_id == uuid)
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchThreshold(uuid) raise api.NoSuchThreshold(uuid)

View File

@@ -29,13 +29,11 @@ class HashMapBase(models.ModelBase):
'mysql_engine': "InnoDB"} 'mysql_engine': "InnoDB"}
fk_to_resolve = {} fk_to_resolve = {}
def save(self, session=None): def save(self):
from cloudkitty import db from cloudkitty import db
if session is None: with db.session_for_write() as session:
session = db.get_session() super(HashMapBase, self).save(session=session)
super(HashMapBase, self).save(session=session)
def as_dict(self): def as_dict(self):
d = {} d = {}

View File

@@ -34,33 +34,32 @@ class PyScripts(api.PyScripts):
return migration return migration
def get_script(self, name=None, uuid=None): def get_script(self, name=None, uuid=None):
session = db.get_session() with db.session_for_read() as session:
try: try:
q = session.query(models.PyScriptsScript) q = session.query(models.PyScriptsScript)
if name: if name:
q = q.filter( q = q.filter(
models.PyScriptsScript.name == name) models.PyScriptsScript.name == name)
elif uuid: elif uuid:
q = q.filter( q = q.filter(
models.PyScriptsScript.script_id == uuid) models.PyScriptsScript.script_id == uuid)
else: else:
raise ValueError('You must specify either name or uuid.') raise ValueError('You must specify either name or uuid.')
res = q.one() res = q.one()
return res return res
except sqlalchemy.orm.exc.NoResultFound: except sqlalchemy.orm.exc.NoResultFound:
raise api.NoSuchScript(name=name, uuid=uuid) raise api.NoSuchScript(name=name, uuid=uuid)
def list_scripts(self): def list_scripts(self):
session = db.get_session() with db.session_for_read() as session:
q = session.query(models.PyScriptsScript) q = session.query(models.PyScriptsScript)
res = q.values( res = q.values(
models.PyScriptsScript.script_id) models.PyScriptsScript.script_id)
return [uuid[0] for uuid in res] return [uuid[0] for uuid in res]
def create_script(self, name, data): def create_script(self, name, data):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
script_db = models.PyScriptsScript(name=name) script_db = models.PyScriptsScript(name=name)
script_db.data = data script_db.data = data
script_db.script_id = uuidutils.generate_uuid() script_db.script_id = uuidutils.generate_uuid()
@@ -73,9 +72,8 @@ class PyScripts(api.PyScripts):
script_db.script_id) script_db.script_id)
def update_script(self, uuid, **kwargs): def update_script(self, uuid, **kwargs):
session = db.get_session()
try: try:
with session.begin(): with db.session_for_write() as session:
q = session.query(models.PyScriptsScript) q = session.query(models.PyScriptsScript)
q = q.filter( q = q.filter(
models.PyScriptsScript.script_id == uuid models.PyScriptsScript.script_id == uuid
@@ -99,16 +97,16 @@ class PyScripts(api.PyScripts):
raise api.NoSuchScript(uuid=uuid) raise api.NoSuchScript(uuid=uuid)
def delete_script(self, name=None, uuid=None): def delete_script(self, name=None, uuid=None):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
models.PyScriptsScript, models.PyScriptsScript,
session) session)
if name: if name:
q = q.filter(models.PyScriptsScript.name == name) q = q.filter(models.PyScriptsScript.name == name)
elif uuid: elif uuid:
q = q.filter(models.PyScriptsScript.script_id == uuid) q = q.filter(models.PyScriptsScript.script_id == uuid)
else: else:
raise ValueError('You must specify either name or uuid.') raise ValueError('You must specify either name or uuid.')
r = q.delete() r = q.delete()
if not r: if not r:
raise api.NoSuchScript(uuid=uuid) raise api.NoSuchScript(uuid=uuid)

View File

@@ -29,13 +29,11 @@ class PyScriptsBase(models.ModelBase):
'mysql_engine': "InnoDB"} 'mysql_engine': "InnoDB"}
fk_to_resolve = {} fk_to_resolve = {}
def save(self, session=None): def save(self):
from cloudkitty import db from cloudkitty import db
if session is None: with db.session_for_write() as session:
session = db.get_session() super(PyScriptsBase, self).save(session=session)
super(PyScriptsBase, self).save(session=session)
def as_dict(self): def as_dict(self):
d = {} d = {}

View File

@@ -56,45 +56,37 @@ class HybridStorage(BaseStorage):
HYBRID_BACKENDS_NAMESPACE, HYBRID_BACKENDS_NAMESPACE,
cfg.CONF.storage_hybrid.backend, cfg.CONF.storage_hybrid.backend,
invoke_on_load=True).driver invoke_on_load=True).driver
self._sql_session = {}
def _check_session(self, tenant_id):
session = self._sql_session.get(tenant_id, None)
if not session:
self._sql_session[tenant_id] = db.get_session()
self._sql_session[tenant_id].begin()
def init(self): def init(self):
migration.upgrade('head') migration.upgrade('head')
self._hybrid_backend.init() self._hybrid_backend.init()
def get_state(self, tenant_id=None): def get_state(self, tenant_id=None):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query(self.state_model, session) q = utils.model_query(self.state_model, session)
if tenant_id: if tenant_id:
q = q.filter(self.state_model.tenant_id == tenant_id) q = q.filter(self.state_model.tenant_id == tenant_id)
q = q.order_by(self.state_model.state.desc()) q = q.order_by(self.state_model.state.desc())
r = q.first() r = q.first()
return r.state if r else None return r.state if r else None
def _set_state(self, tenant_id, state): def _set_state(self, tenant_id, state):
self._check_session(tenant_id) with db.session_for_write() as session:
session = self._sql_session[tenant_id] q = utils.model_query(self.state_model, session)
q = utils.model_query(self.state_model, session) if tenant_id:
if tenant_id: q = q.filter(self.state_model.tenant_id == tenant_id)
q = q.filter(self.state_model.tenant_id == tenant_id) r = q.first()
r = q.first() do_commit = False
do_commit = False if r:
if r: if state > r.state:
if state > r.state: q.update({'state': state})
q.update({'state': state}) do_commit = True
else:
state = self.state_model(tenant_id=tenant_id, state=state)
session.add(state)
do_commit = True do_commit = True
else: if do_commit:
state = self.state_model(tenant_id=tenant_id, state=state) session.commit()
session.add(state)
do_commit = True
if do_commit:
session.commit()
def _commit(self, tenant_id): def _commit(self, tenant_id):
self._hybrid_backend.commit(tenant_id, self.get_state(tenant_id)) self._hybrid_backend.commit(tenant_id, self.get_state(tenant_id))
@@ -105,7 +97,6 @@ class HybridStorage(BaseStorage):
def _post_commit(self, tenant_id): def _post_commit(self, tenant_id):
self._set_state(tenant_id, self.usage_start_dt.get(tenant_id)) self._set_state(tenant_id, self.usage_start_dt.get(tenant_id))
super(HybridStorage, self)._post_commit(tenant_id) super(HybridStorage, self)._post_commit(tenant_id)
del self._sql_session[tenant_id]
def get_total(self, begin=None, end=None, tenant_id=None, def get_total(self, begin=None, end=None, tenant_id=None,
service=None, groupby=None): service=None, groupby=None):

View File

@@ -35,140 +35,131 @@ class SQLAlchemyStorage(storage.BaseStorage):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(SQLAlchemyStorage, self).__init__(**kwargs) super(SQLAlchemyStorage, self).__init__(**kwargs)
self._session = {}
@staticmethod @staticmethod
def init(): def init():
migration.upgrade('head') migration.upgrade('head')
def _pre_commit(self, tenant_id): def _pre_commit(self, tenant_id):
self._check_session(tenant_id)
if not self._has_data.get(tenant_id): if not self._has_data.get(tenant_id):
empty_frame = {'vol': {'qty': 0, 'unit': 'None'}, empty_frame = {'vol': {'qty': 0, 'unit': 'None'},
'rating': {'price': 0}, 'desc': ''} 'rating': {'price': 0}, 'desc': ''}
self._append_time_frame('_NO_DATA_', empty_frame, tenant_id) self._append_time_frame('_NO_DATA_', empty_frame, tenant_id)
def _commit(self, tenant_id): def _commit(self, tenant_id):
self._session[tenant_id].commit() super(SQLAlchemyStorage, self)._commit(tenant_id)
def _post_commit(self, tenant_id): def _post_commit(self, tenant_id):
super(SQLAlchemyStorage, self)._post_commit(tenant_id) super(SQLAlchemyStorage, self)._post_commit(tenant_id)
del self._session[tenant_id]
def _check_session(self, tenant_id):
session = self._session.get(tenant_id)
if not session:
self._session[tenant_id] = db.get_session()
self._session[tenant_id].begin()
def _dispatch(self, data, tenant_id): def _dispatch(self, data, tenant_id):
self._check_session(tenant_id)
for service in data: for service in data:
for frame in data[service]: for frame in data[service]:
self._append_time_frame(service, frame, tenant_id) self._append_time_frame(service, frame, tenant_id)
self._has_data[tenant_id] = True self._has_data[tenant_id] = True
def get_state(self, tenant_id=None): def get_state(self, tenant_id=None):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
self.frame_model, self.frame_model,
session) session)
if tenant_id: if tenant_id:
q = q.filter( q = q.filter(
self.frame_model.tenant_id == tenant_id) self.frame_model.tenant_id == tenant_id)
q = q.order_by( q = q.order_by(
self.frame_model.begin.desc()) self.frame_model.begin.desc())
r = q.first() r = q.first()
if r: if r:
return r.begin return r.begin
def get_total(self, begin=None, end=None, tenant_id=None, service=None, def get_total(self, begin=None, end=None, tenant_id=None, service=None,
groupby=None): groupby=None):
session = db.get_session() with db.session_for_read() as session:
querymodels = [ querymodels = [
sqlalchemy.func.sum(self.frame_model.rate).label('rate') sqlalchemy.func.sum(self.frame_model.rate).label('rate')
] ]
if not begin: if not begin:
begin = ck_utils.get_month_start_timestamp() begin = ck_utils.get_month_start_timestamp()
if not end: if not end:
end = ck_utils.get_next_month_timestamp() end = ck_utils.get_next_month_timestamp()
# Boundary calculation # Boundary calculation
if tenant_id: if tenant_id:
querymodels.append(self.frame_model.tenant_id) querymodels.append(self.frame_model.tenant_id)
if service: if service:
querymodels.append(self.frame_model.res_type) querymodels.append(self.frame_model.res_type)
if groupby: if groupby:
groupbyfields = groupby.split(",") groupbyfields = groupby.split(",")
for field in groupbyfields: for field in groupbyfields:
field_obj = self.frame_model.__dict__.get(field, None) field_obj = self.frame_model.__dict__.get(field, None)
if field_obj and field_obj not in querymodels: if field_obj and field_obj not in querymodels:
querymodels.append(field_obj) querymodels.append(field_obj)
q = session.query(*querymodels) q = session.query(*querymodels)
if tenant_id: if tenant_id:
q = q.filter(
self.frame_model.tenant_id == tenant_id)
if service:
q = q.filter(
self.frame_model.res_type == service)
# begin and end filters are both needed, do not remove one of them.
q = q.filter( q = q.filter(
self.frame_model.tenant_id == tenant_id) self.frame_model.begin.between(begin, end),
if service: self.frame_model.end.between(begin, end),
q = q.filter( self.frame_model.res_type != '_NO_DATA_')
self.frame_model.res_type == service) if groupby:
# begin and end filters are both needed, do not remove one of them. q = q.group_by(sqlalchemy.sql.text(groupby))
q = q.filter(
self.frame_model.begin.between(begin, end),
self.frame_model.end.between(begin, end),
self.frame_model.res_type != '_NO_DATA_')
if groupby:
q = q.group_by(sqlalchemy.sql.text(groupby))
# Order by sum(rate) # Order by sum(rate)
q = q.order_by(sqlalchemy.func.sum(self.frame_model.rate)) q = q.order_by(sqlalchemy.func.sum(self.frame_model.rate))
results = q.all() results = q.all()
totallist = [] totallist = []
for r in results: for r in results:
total = {model.name: value for model, value in zip(querymodels, r)} total = {model.name: value for model, value in zip(querymodels,
total["begin"] = begin r)}
total["end"] = end total["begin"] = begin
totallist.append(total) total["end"] = end
totallist.append(total)
return totallist return totallist
def get_tenants(self, begin, end): def get_tenants(self, begin, end):
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
self.frame_model, self.frame_model,
session) session)
# begin and end filters are both needed, do not remove one of them. # begin and end filters are both needed, do not remove one of them.
q = q.filter( q = q.filter(
self.frame_model.begin.between(begin, end), self.frame_model.begin.between(begin, end),
self.frame_model.end.between(begin, end)) self.frame_model.end.between(begin, end))
tenants = q.distinct().values( tenants = q.distinct().values(
self.frame_model.tenant_id) self.frame_model.tenant_id)
return [tenant.tenant_id for tenant in tenants] return [tenant.tenant_id for tenant in tenants]
def get_time_frame(self, begin, end, **filters): def get_time_frame(self, begin, end, **filters):
if not begin: if not begin:
begin = ck_utils.get_month_start() begin = ck_utils.get_month_start()
if not end: if not end:
end = ck_utils.get_next_month() end = ck_utils.get_next_month()
session = db.get_session() with db.session_for_read() as session:
q = utils.model_query( q = utils.model_query(
self.frame_model, self.frame_model,
session) session)
# begin and end filters are both needed, do not remove one of them. # begin and end filters are both needed, do not remove one of them.
q = q.filter( q = q.filter(
self.frame_model.begin.between(begin, end), self.frame_model.begin.between(begin, end),
self.frame_model.end.between(begin, end)) self.frame_model.end.between(begin, end))
for filter_name, filter_value in filters.items(): for filter_name, filter_value in filters.items():
if filter_value: if filter_value:
q = q.filter( q = q.filter(
getattr(self.frame_model, filter_name) == filter_value) getattr(self.frame_model, filter_name) == filter_value)
if not filters.get('res_type'): if not filters.get('res_type'):
q = q.filter(self.frame_model.res_type != '_NO_DATA_') q = q.filter(self.frame_model.res_type != '_NO_DATA_')
count = q.count() count = q.count()
if not count: if not count:
raise NoTimeFrame() raise NoTimeFrame()
r = q.all() r = q.all()
return [entry.to_cloudkitty(self._collector) for entry in r] return [entry.to_cloudkitty(self._collector) for entry in r]
def _append_time_frame(self, res_type, frame, tenant_id): def _append_time_frame(self, res_type, frame, tenant_id):
vol_dict = frame['vol'] vol_dict = frame['vol']
@@ -201,4 +192,5 @@ class SQLAlchemyStorage(storage.BaseStorage):
:param desc: Resource description (metadata). :param desc: Resource description (metadata).
""" """
frame = self.frame_model(**kwargs) frame = self.frame_model(**kwargs)
self._session[kwargs.get('tenant_id')].add(frame) with db.session_for_write() as session:
session.add(frame)

View File

@@ -85,28 +85,26 @@ class StateManager(object):
:param offset: optional to shift the projection :param offset: optional to shift the projection
:type offset: int :type offset: int
""" """
session = db.get_session() with db.session_for_read() as session:
session.begin()
q = utils.model_query(self.model, session) q = utils.model_query(self.model, session)
if identifier: if identifier:
q = q.filter( q = q.filter(
self.model.identifier.in_(to_list_if_needed(identifier))) self.model.identifier.in_(to_list_if_needed(identifier)))
if fetcher: if fetcher:
q = q.filter( q = q.filter(
self.model.fetcher.in_(to_list_if_needed(fetcher))) self.model.fetcher.in_(to_list_if_needed(fetcher)))
if collector: if collector:
q = q.filter( q = q.filter(
self.model.collector.in_(to_list_if_needed(collector))) self.model.collector.in_(to_list_if_needed(collector)))
if scope_key: if scope_key:
q = q.filter( q = q.filter(
self.model.scope_key.in_(to_list_if_needed(scope_key))) self.model.scope_key.in_(to_list_if_needed(scope_key)))
if active is not None and active != []: if active is not None and active != []:
q = q.filter(self.model.active.in_(to_list_if_needed(active))) q = q.filter(self.model.active.in_(to_list_if_needed(active)))
q = apply_offset_and_limit(limit, offset, q) q = apply_offset_and_limit(limit, offset, q)
r = q.all() r = q.all()
session.close()
for item in r: for item in r:
item.last_processed_timestamp = tzutils.utc_to_local( item.last_processed_timestamp = tzutils.utc_to_local(
@@ -183,20 +181,18 @@ class StateManager(object):
""" """
last_processed_timestamp = tzutils.local_to_utc( last_processed_timestamp = tzutils.local_to_utc(
last_processed_timestamp, naive=True) last_processed_timestamp, naive=True)
session = db.get_session() with db.session_for_write() as session:
session.begin() r = self._get_db_item(
r = self._get_db_item( session, identifier, fetcher, collector, scope_key)
session, identifier, fetcher, collector, scope_key)
if r: if r:
if r.last_processed_timestamp != last_processed_timestamp: if r.last_processed_timestamp != last_processed_timestamp:
r.last_processed_timestamp = last_processed_timestamp r.last_processed_timestamp = last_processed_timestamp
session.commit() session.commit()
else: else:
self.create_scope(identifier, last_processed_timestamp, self.create_scope(identifier, last_processed_timestamp,
fetcher=fetcher, collector=collector, fetcher=fetcher, collector=collector,
scope_key=scope_key) scope_key=scope_key)
session.close()
def create_scope(self, identifier, last_processed_timestamp, fetcher=None, def create_scope(self, identifier, last_processed_timestamp, fetcher=None,
collector=None, scope_key=None, active=True, collector=None, scope_key=None, active=True,
@@ -219,25 +215,18 @@ class StateManager(object):
:type session: object :type session: object
""" """
is_session_reused = True with db.session_for_write() as session:
if not session:
session = db.get_session()
session.begin()
is_session_reused = False
state_object = self.model( state_object = self.model(
identifier=identifier, identifier=identifier,
last_processed_timestamp=last_processed_timestamp, last_processed_timestamp=last_processed_timestamp,
fetcher=fetcher, fetcher=fetcher,
collector=collector, collector=collector,
scope_key=scope_key, scope_key=scope_key,
active=active active=active
) )
session.add(state_object) session.add(state_object)
session.commit() session.commit()
if not is_session_reused:
session.close()
def get_state(self, identifier, def get_state(self, identifier,
fetcher=None, collector=None, scope_key=None): fetcher=None, collector=None, scope_key=None):
@@ -261,11 +250,9 @@ class StateManager(object):
:type scope_key: str :type scope_key: str
:rtype: datetime.datetime :rtype: datetime.datetime
""" """
session = db.get_session() with db.session_for_read() as session:
session.begin() r = self._get_db_item(
r = self._get_db_item( session, identifier, fetcher, collector, scope_key)
session, identifier, fetcher, collector, scope_key)
session.close()
return tzutils.utc_to_local(r.last_processed_timestamp) if r else None return tzutils.utc_to_local(r.last_processed_timestamp) if r else None
def init(self): def init(self):
@@ -274,10 +261,8 @@ class StateManager(object):
# This is made in order to stay compatible with legacy behavior but # This is made in order to stay compatible with legacy behavior but
# shouldn't be used # shouldn't be used
def get_tenants(self, begin=None, end=None): def get_tenants(self, begin=None, end=None):
session = db.get_session() with db.session_for_read() as session:
session.begin() q = utils.model_query(self.model, session)
q = utils.model_query(self.model, session)
session.close()
return [tenant.identifier for tenant in q] return [tenant.identifier for tenant in q]
def update_storage_scope(self, storage_scope_to_update, scope_key=None, def update_storage_scope(self, storage_scope_to_update, scope_key=None,
@@ -295,30 +280,28 @@ class StateManager(object):
:param active: indicates if the storage scope is active for processing :param active: indicates if the storage scope is active for processing
:type active: bool :type active: bool
""" """
session = db.get_session() with db.session_for_write() as session:
session.begin()
db_scope = self._get_db_item(session, db_scope = self._get_db_item(session,
storage_scope_to_update.identifier, storage_scope_to_update.identifier,
storage_scope_to_update.fetcher, storage_scope_to_update.fetcher,
storage_scope_to_update.collector, storage_scope_to_update.collector,
storage_scope_to_update.scope_key) storage_scope_to_update.scope_key)
if scope_key: if scope_key:
db_scope.scope_key = scope_key db_scope.scope_key = scope_key
if fetcher: if fetcher:
db_scope.fetcher = fetcher db_scope.fetcher = fetcher
if collector: if collector:
db_scope.collector = collector db_scope.collector = collector
if active is not None and active != db_scope.active: if active is not None and active != db_scope.active:
db_scope.active = active db_scope.active = active
now = tzutils.localized_now() now = tzutils.localized_now()
db_scope.scope_activation_toggle_date = tzutils.local_to_utc( db_scope.scope_activation_toggle_date = tzutils.local_to_utc(
now, naive=True) now, naive=True)
session.commit() session.commit()
session.close()
def is_storage_scope_active(self, identifier, fetcher=None, def is_storage_scope_active(self, identifier, fetcher=None,
collector=None, scope_key=None): collector=None, scope_key=None):
@@ -334,11 +317,9 @@ class StateManager(object):
:type scope_key: str :type scope_key: str
:rtype: datetime.datetime :rtype: datetime.datetime
""" """
session = db.get_session() with db.session_for_read() as session:
session.begin() r = self._get_db_item(
r = self._get_db_item( session, identifier, fetcher, collector, scope_key)
session, identifier, fetcher, collector, scope_key)
session.close()
return r.active return r.active
@@ -365,23 +346,21 @@ class ReprocessingSchedulerDb(object):
projection. The ordering field will be the `id`. projection. The ordering field will be the `id`.
:type order: str :type order: str
""" """
session = db.get_session() with db.session_for_read() as session:
session.begin()
query = utils.model_query(self.model, session) query = utils.model_query(self.model, session)
if identifier: if identifier:
query = query.filter(self.model.identifier.in_(identifier)) query = query.filter(self.model.identifier.in_(identifier))
if remove_finished: if remove_finished:
query = self.remove_finished_processing_schedules(query) query = self.remove_finished_processing_schedules(query)
if order: if order:
query = query.order_by(sql.text("id %s" % order)) query = query.order_by(sql.text("id %s" % order))
query = apply_offset_and_limit(limit, offset, query) query = apply_offset_and_limit(limit, offset, query)
result_set = query.all() result_set = query.all()
session.close()
return result_set return result_set
def remove_finished_processing_schedules(self, query): def remove_finished_processing_schedules(self, query):
@@ -398,13 +377,10 @@ class ReprocessingSchedulerDb(object):
:type reprocessing_scheduler: models.ReprocessingScheduler :type reprocessing_scheduler: models.ReprocessingScheduler
""" """
session = db.get_session() with db.session_for_write() as session:
session.begin()
session.add(reprocessing_scheduler) session.add(reprocessing_scheduler)
session.commit() session.commit()
session.close()
def get_from_db(self, identifier=None, start_reprocess_time=None, def get_from_db(self, identifier=None, start_reprocess_time=None,
end_reprocess_time=None): end_reprocess_time=None):
@@ -419,12 +395,10 @@ class ReprocessingSchedulerDb(object):
reprocessing schedule reprocessing schedule
:type end_reprocess_time: datetime.datetime :type end_reprocess_time: datetime.datetime
""" """
session = db.get_session() with db.session_for_read() as session:
session.begin()
result_set = self._get_db_item( result_set = self._get_db_item(
end_reprocess_time, identifier, session, start_reprocess_time) end_reprocess_time, identifier, session, start_reprocess_time)
session.close()
return result_set return result_set
@@ -459,22 +433,21 @@ class ReprocessingSchedulerDb(object):
:type new_current_time_stamp: datetime.datetime :type new_current_time_stamp: datetime.datetime
""" """
session = db.get_session() with db.session_for_write() as session:
session.begin()
result_set = self._get_db_item( result_set = self._get_db_item(
end_reprocess_time, identifier, session, start_reprocess_time) end_reprocess_time, identifier, session, start_reprocess_time)
if not result_set: if not result_set:
LOG.warning("Trying to update current time to [%s] for identifier " LOG.warning("Trying to update current time to [%s] for "
"[%s] and reprocessing range [start=%, end=%s], but " "identifier [%s] and reprocessing range [start=%, "
"we could not find a this task in the database.", "end=%s], but we could not find a this task in the"
new_current_time_stamp, identifier, " database.",
start_reprocess_time, end_reprocess_time) new_current_time_stamp, identifier,
return start_reprocess_time, end_reprocess_time)
new_current_time_stamp = tzutils.local_to_utc( return
new_current_time_stamp, naive=True) new_current_time_stamp = tzutils.local_to_utc(
new_current_time_stamp, naive=True)
result_set.current_reprocess_time = new_current_time_stamp result_set.current_reprocess_time = new_current_time_stamp
session.commit() session.commit()
session.close()

View File

@@ -90,8 +90,10 @@ class TestCase(testscenarios.TestWithScenarios, base.BaseTestCase):
self.app_context.push() self.app_context.push()
def tearDown(self): def tearDown(self):
db.get_engine().dispose() with db.session_for_write() as session:
self.auth.stop() engine = session.get_bind()
self.session.stop() engine.dispose()
self.app_context.pop() self.auth.stop()
super(TestCase, self).tearDown() self.session.stop()
self.app_context.pop()
super(TestCase, self).tearDown()

View File

@@ -217,7 +217,9 @@ class ConfigFixture(fixture.GabbiFixture):
def stop_fixture(self): def stop_fixture(self):
if self.conf: if self.conf:
self.conf.reset() self.conf.reset()
db.get_engine().dispose() with db.session_for_write() as session:
engine = session.get_bind()
engine.dispose()
class ConfigFixtureStorageV2(ConfigFixture): class ConfigFixtureStorageV2(ConfigFixture):
@@ -341,11 +343,11 @@ class BaseStorageDataFixture(fixture.GabbiFixture):
def stop_fixture(self): def stop_fixture(self):
model = models.RatedDataFrame model = models.RatedDataFrame
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
model, model,
session) session)
q.delete() q.delete()
class StorageDataFixture(BaseStorageDataFixture): class StorageDataFixture(BaseStorageDataFixture):
@@ -405,11 +407,11 @@ class ScopeStateFixture(fixture.GabbiFixture):
d[0], d[1], fetcher=d[2], collector=d[3], scope_key=d[4]) d[0], d[1], fetcher=d[2], collector=d[3], scope_key=d[4])
def stop_fixture(self): def stop_fixture(self):
session = db.get_session() with db.session_for_write() as session:
q = utils.model_query( q = utils.model_query(
self.sm.model, self.sm.model,
session) session)
q.delete() q.delete()
class CORSConfigFixture(fixture.GabbiFixture): class CORSConfigFixture(fixture.GabbiFixture):

View File

@@ -28,7 +28,8 @@ class StateManagerTest(tests.TestCase):
``filter()`` can be called any number of times, followed by first(), ``filter()`` can be called any number of times, followed by first(),
which will cycle over the ``output`` parameter passed to the which will cycle over the ``output`` parameter passed to the
constructor. The ``first_called`` attributes constructor. The ``first_called`` attribute tracks how many times
first() is called.
""" """
def __init__(self, output, *args, **kwargs): def __init__(self, output, *args, **kwargs):
super(StateManagerTest.QueryMock, self).__init__(*args, **kwargs) super(StateManagerTest.QueryMock, self).__init__(*args, **kwargs)
@@ -80,7 +81,7 @@ class StateManagerTest(tests.TestCase):
self._test_x_state_does_update_columns(self._state.get_state) self._test_x_state_does_update_columns(self._state.get_state)
def test_set_state_does_update_columns(self): def test_set_state_does_update_columns(self):
with mock.patch('cloudkitty.db.get_session'): with mock.patch('cloudkitty.db.session_for_write'):
self._test_x_state_does_update_columns( self._test_x_state_does_update_columns(
lambda x: self._state.set_state(x, datetime(2042, 1, 1))) lambda x: self._state.set_state(x, datetime(2042, 1, 1)))
@@ -101,7 +102,7 @@ class StateManagerTest(tests.TestCase):
self._test_x_state_no_column_update(self._state.get_state) self._test_x_state_no_column_update(self._state.get_state)
def test_set_state_no_column_update(self): def test_set_state_no_column_update(self):
with mock.patch('cloudkitty.db.get_session'): with mock.patch('cloudkitty.db.session_for_write'):
self._test_x_state_no_column_update( self._test_x_state_no_column_update(
lambda x: self._state.set_state(x, datetime(2042, 1, 1))) lambda x: self._state.set_state(x, datetime(2042, 1, 1)))
@@ -111,8 +112,10 @@ class StateManagerTest(tests.TestCase):
self._get_r_mock('a', 'b', 'c', state)) self._get_r_mock('a', 'b', 'c', state))
with mock.patch( with mock.patch(
'oslo_db.sqlalchemy.utils.model_query', 'oslo_db.sqlalchemy.utils.model_query',
new=query_mock), mock.patch('cloudkitty.db.get_session') as sm: new=query_mock), mock.patch(
sm.return_value = session_mock = mock.MagicMock() 'cloudkitty.db.session_for_write') as sm:
sm.return_value.__enter__.return_value = session_mock = \
mock.MagicMock()
self._state.set_state('fake_identifier', state) self._state.set_state('fake_identifier', state)
session_mock.commit.assert_not_called() session_mock.commit.assert_not_called()
session_mock.add.assert_not_called() session_mock.add.assert_not_called()
@@ -123,8 +126,10 @@ class StateManagerTest(tests.TestCase):
new_state = datetime(2042, 1, 1) new_state = datetime(2042, 1, 1)
with mock.patch( with mock.patch(
'oslo_db.sqlalchemy.utils.model_query', 'oslo_db.sqlalchemy.utils.model_query',
new=query_mock), mock.patch('cloudkitty.db.get_session') as sm: new=query_mock), mock.patch(
sm.return_value = session_mock = mock.MagicMock() 'cloudkitty.db.session_for_write') as sm:
sm.return_value.__enter__.return_value = session_mock = \
mock.MagicMock()
self.assertNotEqual(r_mock.state, new_state) self.assertNotEqual(r_mock.state, new_state)
self._state.set_state('fake_identifier', new_state) self._state.set_state('fake_identifier', new_state)
self.assertEqual(r_mock.last_processed_timestamp, new_state) self.assertEqual(r_mock.last_processed_timestamp, new_state)