Merge "Add pluggable range functions for token flush"

This commit is contained in:
Jenkins 2014-08-04 01:36:29 +00:00 committed by Gerrit Code Review
commit 76090ec1fc
2 changed files with 126 additions and 48 deletions

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import uuid
import mock
@ -20,6 +21,7 @@ from oslo.db import exception as db_exception
from oslo.db import options
import sqlalchemy
from sqlalchemy import exc
from testtools import matchers
from keystone.common import sql
from keystone import config
@ -342,37 +344,69 @@ class SqlToken(SqlTests, test_backend.TokenTests):
tok = token_sql.Token()
tok.flush_expired_tokens()
self.assertFalse(mock_sql.get_session().query().filter().limit.called)
filter_mock = mock_sql.get_session().query().filter()
self.assertFalse(filter_mock.limit.called)
self.assertTrue(filter_mock.delete.called_once)
def test_flush_expired_tokens_batch_ibm_db_sa(self):
# TODO(dstanek): This test should be rewritten to be less
# brittle. The code will likely need to be changed first. I
# just copied the spirit of the existing test when I rewrote
# mox -> mock. These tests are brittle because they have the
# call structure for SQLAlchemy encoded in them.
# test ibm_db_sa
def test_flush_expired_tokens_batch_mysql(self):
    """Batched flush on MySQL issues one DELETE per yielded bound.

    We don't need to test IBM DB SA separately, since other tests below
    verify the per-dialect differences in how the batch strategy is
    selected -- the flush loop itself is shared.
    """
    with mock.patch.object(token_sql, 'sql') as mock_sql:
        # NOTE(dstanek): this will allow us to break out of the
        # 'while True' loop
        mock_sql.get_session().query().filter().delete.return_value = 0
        mock_sql.get_session().bind.dialect.name = 'mysql'
        tok = token_sql.Token()
        expiry_mock = mock.Mock()
        ITERS = [1, 2, 3]
        expiry_mock.return_value = iter(ITERS)
        # NOTE(review): this monkeypatch is never undone, so it leaks
        # into later tests -- consider mock.patch.object; confirm no
        # other test relies on the real _expiry_range_batched.
        token_sql._expiry_range_batched = expiry_mock
        tok.flush_expired_tokens()
        # The expiry strategy is only invoked once; the other values
        # come from the generator it returned.  Compare call_count
        # explicitly: Mock.assert_called_once() does not exist in older
        # mock releases, where the attribute access silently passes.
        self.assertEqual(1, expiry_mock.call_count)
        # One DELETE is issued per value yielded by the strategy.
        mock_delete = mock_sql.get_session().query().filter().delete
        self.assertThat(mock_delete.call_args_list,
                        matchers.HasLength(len(ITERS)))
def test_token_flush_batch_size_default(self):
    """Dialects without special handling get a flush batch size of 0."""
    driver = token_sql.Token()
    batch_size = driver.token_flush_batch_size('sqlite')
    self.assertEqual(0, batch_size)
def test_expiry_range_batched(self):
    """The batched range yields DB rows, then the upper bound, then stops.

    While the database query keeps returning a boundary row, its value is
    yielded; once the query raises NotFound, the iterator yields the
    upper-bound function's value exactly once and terminates.
    """
    upper_bound_mock = mock.Mock(side_effect=[1, "final value"])
    sess_mock = mock.Mock()
    query_mock = sess_mock.query().filter().order_by().offset().limit()
    query_mock.one.side_effect = [['test'], sql.NotFound()]
    for i, x in enumerate(token_sql._expiry_range_batched(sess_mock,
                                                          upper_bound_mock,
                                                          batch_size=50)):
        if i == 0:
            # The first time the batch iterator returns, it should return
            # the first result that comes back from the database.
            # NOTE: assertEqual takes (expected, observed) -- kept
            # consistent with the rest of this file.
            self.assertEqual('test', x)
        elif i == 1:
            # The second time, the database range function should return
            # nothing, so the batch iterator returns the result of the
            # upper_bound function
            self.assertEqual("final value", x)
        else:
            self.fail("range batch function returned more than twice")
def test_expiry_range_strategy_sqlite(self):
    """SQLite gets the one-pass strategy that expires everything at once."""
    # NOTE(review): the diff residue here interleaved the removed DB2
    # batch-size test with this new strategy test; only the new test
    # body is kept.
    tok = token_sql.Token()
    sqlite_strategy = tok._expiry_range_strategy('sqlite')
    self.assertEqual(token_sql._expiry_range_all, sqlite_strategy)
def test_expiry_range_strategy_ibm_db_sa(self):
    """DB2 gets the batched strategy bound to a batch size of 100."""
    strategy = token_sql.Token()._expiry_range_strategy('ibm_db_sa')
    self.assertIsInstance(strategy, functools.partial)
    self.assertEqual(strategy.func, token_sql._expiry_range_batched)
    self.assertEqual(strategy.keywords, {'batch_size': 100})
def test_expiry_range_strategy_mysql(self):
    """MySQL gets the batched strategy bound to a batch size of 1000."""
    strategy = token_sql.Token()._expiry_range_strategy('mysql')
    self.assertIsInstance(strategy, functools.partial)
    self.assertEqual(strategy.func, token_sql._expiry_range_batched)
    self.assertEqual(strategy.keywords, {'batch_size': 1000})
class SqlCatalog(SqlTests, test_backend.CatalogTests):

View File

@ -13,16 +13,20 @@
# under the License.
import copy
import functools
from keystone.common import sql
from keystone import config
from keystone import exception
from keystone.i18n import _LI
from keystone.openstack.common import log
from keystone.openstack.common import timeutils
from keystone import token
from keystone.token import provider
CONF = config.CONF
LOG = log.getLogger(__name__)
class TokenModel(sql.ModelBase, sql.DictBase):
@ -40,6 +44,42 @@ class TokenModel(sql.ModelBase, sql.DictBase):
)
def _expiry_range_batched(session, upper_bound_func, batch_size):
    """Return the stop point of the next batch for expiration.

    Return the timestamp of the next token that is `batch_size` rows from
    being the oldest expired token.
    """
    # This expiry strategy splits the tokens into roughly equal sized batches
    # to be deleted.  It does this by finding the timestamp of a token
    # `batch_size` rows from the oldest token and yielding that to the caller.
    # It's expected that the caller will then delete all rows with a timestamp
    # equal to or older than the one yielded.  This may delete slightly more
    # tokens than the batch_size, but that should be ok in almost all cases.
    # NOTE: pass batch_size as a lazy log argument instead of %-formatting
    # the translated message eagerly -- consistent with the logging calls
    # elsewhere in this module.
    LOG.info(_LI('Token expiration batch size: %d'), batch_size)
    query = session.query(TokenModel.expires)
    query = query.filter(TokenModel.expires < upper_bound_func())
    query = query.order_by(TokenModel.expires)
    query = query.offset(batch_size - 1)
    query = query.limit(1)
    while True:
        try:
            next_expiration = query.one()[0]
        except sql.NotFound:
            # There are less than `batch_size` rows remaining, so fall
            # through to the normal delete
            break
        yield next_expiration
    yield upper_bound_func()
def _expiry_range_all(session, upper_bound_func):
"""Expires all tokens in one pass."""
yield upper_bound_func()
class Token(token.persistence.Driver):
# Public interface
def get_token(self, token_id):
@ -192,41 +232,45 @@ class Token(token.persistence.Driver):
tokens.append(record)
return tokens
def _expiry_range_strategy(self, dialect):
    """Choose a token range expiration strategy.

    Based on the DB dialect, select an expiry range callable that is
    appropriate.  (The removed-diff residue of the old
    token_flush_batch_size method has been stripped from this block.)
    """
    # DB2 and MySQL can both benefit from a batched strategy.  On DB2 the
    # transaction log can fill up and on MySQL w/Galera, large
    # transactions can exceed the maximum write set size.
    if dialect == 'ibm_db_sa':
        # Limit of 100 is known to not fill a transaction log
        # of default maximum size while not significantly
        # impacting the performance of large token purges on
        # systems where the maximum transaction log size has
        # been increased beyond the default.
        return functools.partial(_expiry_range_batched,
                                 batch_size=100)
    elif dialect == 'mysql':
        # We want somewhat more than 100, since Galera replication delay is
        # at least RTT*2.  This can be a significant amount of time if
        # doing replication across a WAN.
        return functools.partial(_expiry_range_batched,
                                 batch_size=1000)
    # Every other dialect expires everything in one pass.
    return _expiry_range_all
def flush_expired_tokens(self):
    """Delete all expired tokens, batching per the dialect's strategy.

    The range strategy yields one or more expiry bounds; everything at or
    before each bound is deleted.  (The removed-diff residue of the old
    batch-size implementation has been stripped from this block.)
    """
    session = sql.get_session()
    dialect = session.bind.dialect.name
    expiry_range_func = self._expiry_range_strategy(dialect)
    query = session.query(TokenModel.expires)
    total_removed = 0
    upper_bound_func = timeutils.utcnow
    for expiry_time in expiry_range_func(session, upper_bound_func):
        # The comparison must be inclusive (<=): the batched strategy's
        # contract is "delete rows equal to or older than the yielded
        # bound".  With a strict '<', a run of tokens sharing the
        # boundary timestamp would never be deleted, and the batched
        # generator would keep yielding the same bound forever.
        delete_query = query.filter(TokenModel.expires <= expiry_time)
        row_count = delete_query.delete(synchronize_session=False)
        total_removed += row_count
        LOG.debug('Removed %d expired tokens', total_removed)

    session.flush()
    LOG.info(_LI('Total expired tokens removed: %d'), total_removed)