Refactored session access
Encapsulated all access to sqlalchemy.get_session This is needed for further refactoring of session handling Change-Id: I35c64b69c7bde80713ca1f0e7f7f83bd57493668
This commit is contained in:
		@@ -176,13 +176,15 @@ class Connection(base.Connection):
 | 
			
		||||
            conf.database.connection = \
 | 
			
		||||
                os.environ.get('CEILOMETER_TEST_SQL_URL', url)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_db_session():
 | 
			
		||||
        return sqlalchemy_session.get_session()
 | 
			
		||||
 | 
			
		||||
    def upgrade(self):
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        migration.db_sync(session.get_bind())
 | 
			
		||||
        migration.db_sync(self._get_db_session().get_bind())
 | 
			
		||||
 | 
			
		||||
    def clear(self):
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        engine = session.get_bind()
 | 
			
		||||
        engine = self._get_db_session().get_bind()
 | 
			
		||||
        for table in reversed(models.Base.metadata.sorted_tables):
 | 
			
		||||
            engine.execute(table.delete())
 | 
			
		||||
 | 
			
		||||
@@ -221,27 +223,26 @@ class Connection(base.Connection):
 | 
			
		||||
            setattr(obj, k, kwargs[k])
 | 
			
		||||
        return obj
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def record_metering_data(cls, data):
 | 
			
		||||
    def record_metering_data(self, data):
 | 
			
		||||
        """Write the data to the backend storage system.
 | 
			
		||||
 | 
			
		||||
        :param data: a dictionary such as returned by
 | 
			
		||||
                     ceilometer.meter.meter_message_from_counter
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            # Record the updated resource metadata
 | 
			
		||||
            rmetadata = data['resource_metadata']
 | 
			
		||||
            source = cls._create_or_update(session, models.Source,
 | 
			
		||||
                                           data['source'])
 | 
			
		||||
            user = cls._create_or_update(session, models.User, data['user_id'],
 | 
			
		||||
                                         source)
 | 
			
		||||
            project = cls._create_or_update(session, models.Project,
 | 
			
		||||
                                            data['project_id'], source)
 | 
			
		||||
            resource = cls._create_or_update(session, models.Resource,
 | 
			
		||||
                                             data['resource_id'], source,
 | 
			
		||||
                                             user=user, project=project,
 | 
			
		||||
                                             resource_metadata=rmetadata)
 | 
			
		||||
            source = self._create_or_update(session, models.Source,
 | 
			
		||||
                                            data['source'])
 | 
			
		||||
            user = self._create_or_update(session, models.User,
 | 
			
		||||
                                          data['user_id'], source)
 | 
			
		||||
            project = self._create_or_update(session, models.Project,
 | 
			
		||||
                                             data['project_id'], source)
 | 
			
		||||
            resource = self._create_or_update(session, models.Resource,
 | 
			
		||||
                                              data['resource_id'], source,
 | 
			
		||||
                                              user=user, project=project,
 | 
			
		||||
                                              resource_metadata=rmetadata)
 | 
			
		||||
 | 
			
		||||
            # Record the raw data for the meter.
 | 
			
		||||
            meter = models.Meter(counter_type=data['counter_type'],
 | 
			
		||||
@@ -273,8 +274,7 @@ class Connection(base.Connection):
 | 
			
		||||
                                               meta_key=key,
 | 
			
		||||
                                               value=v))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def clear_expired_metering_data(ttl):
 | 
			
		||||
    def clear_expired_metering_data(self, ttl):
 | 
			
		||||
        """Clear expired data from the backend storage system according to the
 | 
			
		||||
        time-to-live.
 | 
			
		||||
 | 
			
		||||
@@ -282,7 +282,7 @@ class Connection(base.Connection):
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            end = timeutils.utcnow() - datetime.timedelta(seconds=ttl)
 | 
			
		||||
            meter_query = session.query(models.Meter)\
 | 
			
		||||
@@ -320,32 +320,27 @@ class Connection(base.Connection):
 | 
			
		||||
            for res_obj in query.all():
 | 
			
		||||
                session.delete(res_obj)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_users(source=None):
 | 
			
		||||
    def get_users(self, source=None):
 | 
			
		||||
        """Return an iterable of user id strings.
 | 
			
		||||
 | 
			
		||||
        :param source: Optional source filter.
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        query = session.query(models.User.id)
 | 
			
		||||
        query = self._get_db_session().query(models.User.id)
 | 
			
		||||
        if source is not None:
 | 
			
		||||
            query = query.filter(models.User.sources.any(id=source))
 | 
			
		||||
        return (x[0] for x in query.all())
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_projects(source=None):
 | 
			
		||||
    def get_projects(self, source=None):
 | 
			
		||||
        """Return an iterable of project id strings.
 | 
			
		||||
 | 
			
		||||
        :param source: Optional source filter.
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        query = session.query(models.Project.id)
 | 
			
		||||
        query = self._get_db_session().query(models.Project.id)
 | 
			
		||||
        if source:
 | 
			
		||||
            query = query.filter(models.Project.sources.any(id=source))
 | 
			
		||||
        return (x[0] for x in query.all())
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_resources(user=None, project=None, source=None,
 | 
			
		||||
    def get_resources(self, user=None, project=None, source=None,
 | 
			
		||||
                      start_timestamp=None, start_timestamp_op=None,
 | 
			
		||||
                      end_timestamp=None, end_timestamp_op=None,
 | 
			
		||||
                      metaquery={}, resource=None, pagination=None):
 | 
			
		||||
@@ -369,11 +364,11 @@ class Connection(base.Connection):
 | 
			
		||||
        if pagination:
 | 
			
		||||
            raise NotImplementedError(_('Pagination not implemented'))
 | 
			
		||||
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
 | 
			
		||||
        # (thomasm) We need to get the max timestamp first, since that's the
 | 
			
		||||
        # most accurate. We also need to filter down in the subquery to
 | 
			
		||||
        # constrain what we have to JOIN on later.
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
 | 
			
		||||
        ts_subquery = session.query(
 | 
			
		||||
            models.Meter.resource_id,
 | 
			
		||||
            func.max(models.Meter.timestamp).label("max_ts"),
 | 
			
		||||
@@ -447,8 +442,7 @@ class Connection(base.Connection):
 | 
			
		||||
                metadata=meter.resource_metadata,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_meters(user=None, project=None, resource=None, source=None,
 | 
			
		||||
    def get_meters(self, user=None, project=None, resource=None, source=None,
 | 
			
		||||
                   metaquery={}, pagination=None):
 | 
			
		||||
        """Return an iterable of api_models.Meter instances
 | 
			
		||||
 | 
			
		||||
@@ -463,7 +457,7 @@ class Connection(base.Connection):
 | 
			
		||||
        if pagination:
 | 
			
		||||
            raise NotImplementedError(_('Pagination not implemented'))
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
 | 
			
		||||
        # Meter table will store large records and join with resource
 | 
			
		||||
        # will be very slow.
 | 
			
		||||
@@ -514,8 +508,7 @@ class Connection(base.Connection):
 | 
			
		||||
                source=resource.sources[0].id,
 | 
			
		||||
                user_id=resource.user_id)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_samples(sample_filter, limit=None):
 | 
			
		||||
    def get_samples(self, sample_filter, limit=None):
 | 
			
		||||
        """Return an iterable of api_models.Samples.
 | 
			
		||||
 | 
			
		||||
        :param sample_filter: Filter.
 | 
			
		||||
@@ -524,7 +517,7 @@ class Connection(base.Connection):
 | 
			
		||||
        if limit == 0:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        query = session.query(models.Meter)
 | 
			
		||||
        query = make_query_from_filter(session, query, sample_filter,
 | 
			
		||||
                                       require_meter=False)
 | 
			
		||||
@@ -555,8 +548,7 @@ class Connection(base.Connection):
 | 
			
		||||
                message_signature=s.message_signature,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _make_stats_query(sample_filter, groupby):
 | 
			
		||||
    def _make_stats_query(self, sample_filter, groupby):
 | 
			
		||||
        select = [
 | 
			
		||||
            models.Meter.counter_unit.label('unit'),
 | 
			
		||||
            func.min(models.Meter.timestamp).label('tsmin'),
 | 
			
		||||
@@ -568,7 +560,7 @@ class Connection(base.Connection):
 | 
			
		||||
            func.count(models.Meter.counter_volume).label('count'),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
 | 
			
		||||
        if groupby:
 | 
			
		||||
            group_attributes = [getattr(models.Meter, g) for g in groupby]
 | 
			
		||||
@@ -683,7 +675,7 @@ class Connection(base.Connection):
 | 
			
		||||
        if pagination:
 | 
			
		||||
            raise NotImplementedError(_('Pagination not implemented'))
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        query = session.query(models.Alarm)
 | 
			
		||||
        if name is not None:
 | 
			
		||||
            query = query.filter(models.Alarm.name == name)
 | 
			
		||||
@@ -703,7 +695,7 @@ class Connection(base.Connection):
 | 
			
		||||
 | 
			
		||||
        :param alarm: The alarm to create.
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            alarm_row = models.Alarm(id=alarm.alarm_id)
 | 
			
		||||
            alarm_row.update(alarm.as_dict())
 | 
			
		||||
@@ -716,7 +708,7 @@ class Connection(base.Connection):
 | 
			
		||||
 | 
			
		||||
        :param alarm: the new Alarm to update
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            Connection._create_or_update(session, models.User,
 | 
			
		||||
                                         alarm.user_id)
 | 
			
		||||
@@ -727,13 +719,12 @@ class Connection(base.Connection):
 | 
			
		||||
 | 
			
		||||
        return self._row_to_alarm_model(alarm_row)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def delete_alarm(alarm_id):
 | 
			
		||||
    def delete_alarm(self, alarm_id):
 | 
			
		||||
        """Delete a alarm
 | 
			
		||||
 | 
			
		||||
        :param alarm_id: ID of the alarm to delete
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            session.query(models.Alarm).filter(
 | 
			
		||||
                models.Alarm.id == alarm_id).delete()
 | 
			
		||||
@@ -776,7 +767,7 @@ class Connection(base.Connection):
 | 
			
		||||
        :param end_timestamp: Optional modified timestamp end range
 | 
			
		||||
        :param end_timestamp_op: Optional timestamp end range operation
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        query = session.query(models.AlarmChange)
 | 
			
		||||
        query = query.filter(models.AlarmChange.alarm_id == alarm_id)
 | 
			
		||||
 | 
			
		||||
@@ -810,7 +801,7 @@ class Connection(base.Connection):
 | 
			
		||||
    def record_alarm_change(self, alarm_change):
 | 
			
		||||
        """Record alarm change event.
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            Connection._create_or_update(session, models.User,
 | 
			
		||||
                                         alarm_change['user_id'])
 | 
			
		||||
@@ -823,13 +814,12 @@ class Connection(base.Connection):
 | 
			
		||||
            alarm_change_row.update(alarm_change)
 | 
			
		||||
            session.add(alarm_change_row)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_or_create_trait_type(trait_type, data_type, session=None):
 | 
			
		||||
    def _get_or_create_trait_type(self, trait_type, data_type, session=None):
 | 
			
		||||
        """Find if this trait already exists in the database, and
 | 
			
		||||
        if it does not, create a new entry in the trait type table.
 | 
			
		||||
        """
 | 
			
		||||
        if session is None:
 | 
			
		||||
            session = sqlalchemy_session.get_session()
 | 
			
		||||
            session = self._get_db_session()
 | 
			
		||||
        with session.begin(subtransactions=True):
 | 
			
		||||
            tt = session.query(models.TraitType).filter(
 | 
			
		||||
                models.TraitType.desc == trait_type,
 | 
			
		||||
@@ -839,15 +829,14 @@ class Connection(base.Connection):
 | 
			
		||||
                session.add(tt)
 | 
			
		||||
        return tt
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _make_trait(cls, trait_model, event, session=None):
 | 
			
		||||
    def _make_trait(self, trait_model, event, session=None):
 | 
			
		||||
        """Make a new Trait from a Trait model.
 | 
			
		||||
 | 
			
		||||
        Doesn't flush or add to session.
 | 
			
		||||
        """
 | 
			
		||||
        trait_type = cls._get_or_create_trait_type(trait_model.name,
 | 
			
		||||
                                                   trait_model.dtype,
 | 
			
		||||
                                                   session)
 | 
			
		||||
        trait_type = self._get_or_create_trait_type(trait_model.name,
 | 
			
		||||
                                                    trait_model.dtype,
 | 
			
		||||
                                                    session)
 | 
			
		||||
        value_map = models.Trait._value_map
 | 
			
		||||
        values = {'t_string': None, 't_float': None,
 | 
			
		||||
                  't_int': None, 't_datetime': None}
 | 
			
		||||
@@ -855,15 +844,14 @@ class Connection(base.Connection):
 | 
			
		||||
        values[value_map[trait_model.dtype]] = value
 | 
			
		||||
        return models.Trait(trait_type, event, **values)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_or_create_event_type(event_type, session=None):
 | 
			
		||||
    def _get_or_create_event_type(self, event_type, session=None):
 | 
			
		||||
        """Here, we check to see if an event type with the supplied
 | 
			
		||||
        name already exists. If not, we create it and return the record.
 | 
			
		||||
 | 
			
		||||
        This may result in a flush.
 | 
			
		||||
        """
 | 
			
		||||
        if session is None:
 | 
			
		||||
            session = sqlalchemy_session.get_session()
 | 
			
		||||
            session = self._get_db_session()
 | 
			
		||||
        with session.begin(subtransactions=True):
 | 
			
		||||
            et = session.query(models.EventType).filter(
 | 
			
		||||
                models.EventType.desc == event_type).first()
 | 
			
		||||
@@ -872,13 +860,12 @@ class Connection(base.Connection):
 | 
			
		||||
                session.add(et)
 | 
			
		||||
        return et
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _record_event(cls, session, event_model):
 | 
			
		||||
    def _record_event(self, session, event_model):
 | 
			
		||||
        """Store a single Event, including related Traits.
 | 
			
		||||
        """
 | 
			
		||||
        with session.begin(subtransactions=True):
 | 
			
		||||
            event_type = cls._get_or_create_event_type(event_model.event_type,
 | 
			
		||||
                                                       session=session)
 | 
			
		||||
            event_type = self._get_or_create_event_type(event_model.event_type,
 | 
			
		||||
                                                        session=session)
 | 
			
		||||
 | 
			
		||||
            event = models.Event(event_model.message_id, event_type,
 | 
			
		||||
                                 event_model.generated)
 | 
			
		||||
@@ -887,7 +874,7 @@ class Connection(base.Connection):
 | 
			
		||||
            new_traits = []
 | 
			
		||||
            if event_model.traits:
 | 
			
		||||
                for trait in event_model.traits:
 | 
			
		||||
                    t = cls._make_trait(trait, event, session=session)
 | 
			
		||||
                    t = self._make_trait(trait, event, session=session)
 | 
			
		||||
                    session.add(t)
 | 
			
		||||
                    new_traits.append(t)
 | 
			
		||||
 | 
			
		||||
@@ -907,7 +894,7 @@ class Connection(base.Connection):
 | 
			
		||||
        Flush when they're all added, unless new EventTypes or
 | 
			
		||||
        TraitTypes are added along the way.
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        events = []
 | 
			
		||||
        problem_events = []
 | 
			
		||||
        for event_model in event_models:
 | 
			
		||||
@@ -933,7 +920,7 @@ class Connection(base.Connection):
 | 
			
		||||
 | 
			
		||||
        start = event_filter.start_time
 | 
			
		||||
        end = event_filter.end_time
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        LOG.debug(_("Getting events that match filter: %s") % event_filter)
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            event_query = session.query(models.Event)
 | 
			
		||||
@@ -1030,12 +1017,11 @@ class Connection(base.Connection):
 | 
			
		||||
        event_models = event_models_dict.values()
 | 
			
		||||
        return sorted(event_models, key=operator.attrgetter('generated'))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_event_types():
 | 
			
		||||
    def get_event_types(self):
 | 
			
		||||
        """Return all event types as an iterable of strings.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            query = session.query(models.EventType.desc)\
 | 
			
		||||
                .order_by(models.EventType.desc)
 | 
			
		||||
@@ -1043,15 +1029,14 @@ class Connection(base.Connection):
 | 
			
		||||
                # The query returns a tuple with one element.
 | 
			
		||||
                yield name[0]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_trait_types(event_type):
 | 
			
		||||
    def get_trait_types(self, event_type):
 | 
			
		||||
        """Return a dictionary containing the name and data type of
 | 
			
		||||
        the trait type. Only trait types for the provided event_type are
 | 
			
		||||
        returned.
 | 
			
		||||
 | 
			
		||||
        :param event_type: the type of the Event
 | 
			
		||||
        """
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
 | 
			
		||||
        LOG.debug(_("Get traits for %s") % event_type)
 | 
			
		||||
        with session.begin():
 | 
			
		||||
@@ -1075,8 +1060,7 @@ class Connection(base.Connection):
 | 
			
		||||
            for desc, type in query.all():
 | 
			
		||||
                yield {'name': desc, 'data_type': type}
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_traits(event_type, trait_type=None):
 | 
			
		||||
    def get_traits(self, event_type, trait_type=None):
 | 
			
		||||
        """Return all trait instances associated with an event_type. If
 | 
			
		||||
        trait_type is specified, only return instances of that trait type.
 | 
			
		||||
 | 
			
		||||
@@ -1084,7 +1068,7 @@ class Connection(base.Connection):
 | 
			
		||||
        :param trait_type: the name of the Trait to filter by
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        session = sqlalchemy_session.get_session()
 | 
			
		||||
        session = self._get_db_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
            trait_type_filters = [models.TraitType.id ==
 | 
			
		||||
                                  models.Trait.trait_type_id]
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user