Merge "Add pluggable range functions for token flush"
This commit is contained in:
commit
76090ec1fc
@ -13,6 +13,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import uuid
|
||||
|
||||
import mock
|
||||
@ -20,6 +21,7 @@ from oslo.db import exception as db_exception
|
||||
from oslo.db import options
|
||||
import sqlalchemy
|
||||
from sqlalchemy import exc
|
||||
from testtools import matchers
|
||||
|
||||
from keystone.common import sql
|
||||
from keystone import config
|
||||
@ -342,37 +344,69 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
||||
tok = token_sql.Token()
|
||||
tok.flush_expired_tokens()
|
||||
|
||||
self.assertFalse(mock_sql.get_session().query().filter().limit.called)
|
||||
filter_mock = mock_sql.get_session().query().filter()
|
||||
self.assertFalse(filter_mock.limit.called)
|
||||
self.assertTrue(filter_mock.delete.called_once)
|
||||
|
||||
def test_flush_expired_tokens_batch_ibm_db_sa(self):
|
||||
# TODO(dstanek): This test should be rewritten to be less
|
||||
# brittle. The code will likely need to be changed first. I
|
||||
# just copied the spirit of the existing test when I rewrote
|
||||
# mox -> mock. These tests are brittle because they have the
|
||||
# call structure for SQLAlchemy encoded in them.
|
||||
|
||||
# test ibm_db_sa
|
||||
def test_flush_expired_tokens_batch_mysql(self):
|
||||
# test mysql dialect, we don't need to test IBM DB SA separately, since
|
||||
# other tests below test the differences between how they use the batch
|
||||
# strategy
|
||||
with mock.patch.object(token_sql, 'sql') as mock_sql:
|
||||
# NOTE(dstanek): this will allow us to break out of the
|
||||
# 'while True' loop
|
||||
mock_sql.get_session().query().filter().delete.return_value = 0
|
||||
|
||||
mock_sql.get_session().bind.dialect.name = 'ibm_db_sa'
|
||||
mock_sql.get_session().bind.dialect.name = 'mysql'
|
||||
tok = token_sql.Token()
|
||||
expiry_mock = mock.Mock()
|
||||
ITERS = [1, 2, 3]
|
||||
expiry_mock.return_value = iter(ITERS)
|
||||
token_sql._expiry_range_batched = expiry_mock
|
||||
tok.flush_expired_tokens()
|
||||
|
||||
mock_limit = mock_sql.get_session().query().filter().limit
|
||||
mock_limit.assert_called_with(100)
|
||||
# The expiry strategy is only invoked once, the other calls are via
|
||||
# the yield return.
|
||||
expiry_mock.assert_called_once()
|
||||
mock_delete = mock_sql.get_session().query().filter().delete
|
||||
self.assertThat(mock_delete.call_args_list,
|
||||
matchers.HasLength(len(ITERS)))
|
||||
|
||||
def test_token_flush_batch_size_default(self):
|
||||
tok = token_sql.Token()
|
||||
sqlite_batch = tok.token_flush_batch_size('sqlite')
|
||||
self.assertEqual(0, sqlite_batch)
|
||||
def test_expiry_range_batched(self):
    """Check the values yielded by the batched expiry range generator.

    The mocked query returns one row and then raises ``sql.NotFound``, so
    the generator is expected to yield that row's first column followed
    by the second value produced by the upper bound function.
    """
    # The first upper bound call is consumed while building the query
    # filter; the second supplies the value yielded once the query finds
    # no further rows.
    upper_bound_mock = mock.Mock(side_effect=[1, "final value"])
    sess_mock = mock.Mock()
    query_mock = sess_mock.query().filter().order_by().offset().limit()
    query_mock.one.side_effect = [['test'], sql.NotFound()]

    batches = list(token_sql._expiry_range_batched(sess_mock,
                                                   upper_bound_mock,
                                                   batch_size=50))

    # Comparing the whole sequence also fails when the generator yields
    # fewer than two values, which per-iteration checks inside a loop
    # would silently let through.
    self.assertEqual(['test', 'final value'], batches)
|
||||
|
||||
def test_token_flush_batch_size_db2(self):
|
||||
def test_expiry_range_strategy_sqlite(self):
|
||||
tok = token_sql.Token()
|
||||
db2_batch = tok.token_flush_batch_size('ibm_db_sa')
|
||||
self.assertEqual(100, db2_batch)
|
||||
sqlite_strategy = tok._expiry_range_strategy('sqlite')
|
||||
self.assertEqual(token_sql._expiry_range_all, sqlite_strategy)
|
||||
|
||||
def test_expiry_range_strategy_ibm_db_sa(self):
    """The ibm_db_sa dialect gets the batched strategy, batch size 100."""
    tok = token_sql.Token()
    db2_strategy = tok._expiry_range_strategy('ibm_db_sa')
    self.assertIsInstance(db2_strategy, functools.partial)
    # assertEqual(expected, observed) so failure messages read correctly.
    self.assertEqual(token_sql._expiry_range_batched, db2_strategy.func)
    self.assertEqual({'batch_size': 100}, db2_strategy.keywords)
|
||||
|
||||
def test_expiry_range_strategy_mysql(self):
    """The mysql dialect gets the batched strategy, batch size 1000."""
    tok = token_sql.Token()
    mysql_strategy = tok._expiry_range_strategy('mysql')
    self.assertIsInstance(mysql_strategy, functools.partial)
    # assertEqual(expected, observed) so failure messages read correctly.
    self.assertEqual(token_sql._expiry_range_batched, mysql_strategy.func)
    self.assertEqual({'batch_size': 1000}, mysql_strategy.keywords)
|
||||
|
||||
|
||||
class SqlCatalog(SqlTests, test_backend.CatalogTests):
|
||||
|
@ -13,16 +13,20 @@
|
||||
# under the License.
|
||||
|
||||
import copy
|
||||
import functools
|
||||
|
||||
from keystone.common import sql
|
||||
from keystone import config
|
||||
from keystone import exception
|
||||
from keystone.i18n import _LI
|
||||
from keystone.openstack.common import log
|
||||
from keystone.openstack.common import timeutils
|
||||
from keystone import token
|
||||
from keystone.token import provider
|
||||
|
||||
|
||||
CONF = config.CONF
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenModel(sql.ModelBase, sql.DictBase):
|
||||
@ -40,6 +44,42 @@ class TokenModel(sql.ModelBase, sql.DictBase):
|
||||
)
|
||||
|
||||
|
||||
def _expiry_range_batched(session, upper_bound_func, batch_size):
    """Yield successive expiry cut-offs that split a purge into batches.

    Each yielded value is the ``expires`` timestamp of the row that is
    `batch_size` rows from being the oldest token expiring before the
    upper bound. Once fewer than `batch_size` such rows remain, a fresh
    ``upper_bound_func()`` value is yielded and the generator finishes.

    :param session: session object used to query ``TokenModel.expires``
    :param upper_bound_func: zero-argument callable returning the expiry
                             upper bound; called once to build the query
                             and once more for the final yielded value
    :param batch_size: approximate number of rows covered by each batch
    :returns: a generator of expiry cut-off values

    """

    # This expiry strategy splits the tokens into roughly equal sized batches
    # to be deleted. It does this by finding the timestamp of a token
    # `batch_size` rows from the oldest token and yielding that to the caller.
    # It's expected that the caller will then delete all rows with a timestamp
    # equal to or older than the one yielded. This may delete slightly more
    # tokens than the batch_size, but that should be ok in almost all cases.
    LOG.info(_LI('Token expiration batch size: %d'), batch_size)
    query = session.query(TokenModel.expires)
    query = query.filter(TokenModel.expires < upper_bound_func())
    query = query.order_by(TokenModel.expires)
    query = query.offset(batch_size - 1)
    query = query.limit(1)
    # The query is built once and re-executed on every pass, so the loop
    # only advances if the caller removes rows between yields.
    while True:
        try:
            next_expiration = query.one()[0]
        except sql.NotFound:
            # There are fewer than `batch_size` rows remaining, so fall
            # through to the normal delete
            break
        yield next_expiration
    yield upper_bound_func()
|
||||
|
||||
|
||||
def _expiry_range_all(session, upper_bound_func):
|
||||
"""Expires all tokens in one pass."""
|
||||
|
||||
yield upper_bound_func()
|
||||
|
||||
|
||||
class Token(token.persistence.Driver):
|
||||
# Public interface
|
||||
def get_token(self, token_id):
|
||||
@ -192,41 +232,45 @@ class Token(token.persistence.Driver):
|
||||
tokens.append(record)
|
||||
return tokens
|
||||
|
||||
def token_flush_batch_size(self, dialect):
|
||||
batch_size = 0
|
||||
def _expiry_range_strategy(self, dialect):
|
||||
"""Choose a token range expiration strategy
|
||||
|
||||
Based on the DB dialect, select an expiry range callable that is
|
||||
appropriate.
|
||||
"""
|
||||
|
||||
# DB2 and MySQL can both benefit from a batched strategy. On DB2 the
|
||||
# transaction log can fill up and on MySQL w/Galera, large
|
||||
# transactions can exceed the maximum write set size.
|
||||
if dialect == 'ibm_db_sa':
|
||||
# This functionality is limited to DB2, because
|
||||
# it is necessary to prevent the transaction log
|
||||
# from filling up, whereas at least some of the
|
||||
# other supported databases do not support update
|
||||
# queries with LIMIT subqueries nor do they appear
|
||||
# to require the use of such queries when deleting
|
||||
# large numbers of records at once.
|
||||
batch_size = 100
|
||||
# Limit of 100 is known to not fill a transaction log
|
||||
# of default maximum size while not significantly
|
||||
# impacting the performance of large token purges on
|
||||
# systems where the maximum transaction log size has
|
||||
# been increased beyond the default.
|
||||
return batch_size
|
||||
return functools.partial(_expiry_range_batched,
|
||||
batch_size=100)
|
||||
elif dialect == 'mysql':
|
||||
# We want somewhat more than 100, since Galera replication delay is
|
||||
# at least RTT*2. This can be a significant amount of time if
|
||||
# doing replication across a WAN.
|
||||
return functools.partial(_expiry_range_batched,
|
||||
batch_size=1000)
|
||||
return _expiry_range_all
|
||||
|
||||
def flush_expired_tokens(self):
|
||||
session = sql.get_session()
|
||||
dialect = session.bind.dialect.name
|
||||
batch_size = self.token_flush_batch_size(dialect)
|
||||
if batch_size > 0:
|
||||
query = session.query(TokenModel.id)
|
||||
query = query.filter(TokenModel.expires < timeutils.utcnow())
|
||||
query = query.limit(batch_size).subquery()
|
||||
delete_query = (session.query(TokenModel).
|
||||
filter(TokenModel.id.in_(query)))
|
||||
while True:
|
||||
rowcount = delete_query.delete(synchronize_session=False)
|
||||
if rowcount == 0:
|
||||
break
|
||||
else:
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires < timeutils.utcnow())
|
||||
query.delete(synchronize_session=False)
|
||||
expiry_range_func = self._expiry_range_strategy(dialect)
|
||||
query = session.query(TokenModel.expires)
|
||||
total_removed = 0
|
||||
upper_bound_func = timeutils.utcnow
|
||||
for expiry_time in expiry_range_func(session, upper_bound_func):
|
||||
delete_query = query.filter(TokenModel.expires <
|
||||
expiry_time)
|
||||
row_count = delete_query.delete(synchronize_session=False)
|
||||
total_removed += row_count
|
||||
LOG.debug('Removed %d expired tokens', total_removed)
|
||||
|
||||
session.flush()
|
||||
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
|
||||
|
Loading…
Reference in New Issue
Block a user