Add support for client-side rate limiting

shade/openstacksdk has implemented client-side rate limiting on top of
keystoneauth for ages and uses it extensively in nodepool. As part of an
effort to refactor that code a new approach was devised which was much
simpler and therfore suitable for inclusion in keystoneauth directly.

The underlying goal is two-fold, but fundamentally is about allowing a
user to add some settings so that they can avoid slamming their cloud.
First, allow a user to express that they never want to exceed a given
rate. Second, allow a user to limit the number of concurrent requests
allowed to be in flight.

The settings and logic are added to Adapter and not Session so that the
settings can easily be per-service. There is no need to block requests
to nova on a neutron rate limit, after all.

Co-Authored-By: Ian Wienand <iwienand@redhat.com>
Needed-By: https://review.openstack.org/604926
Change-Id: Ic831e03a37d804f45b7ee58c87f92fa0f4411ad8
changes/43/605043/7
Monty Taylor 4 years ago
parent e878df1a16
commit 09934718f7

@ -0,0 +1,104 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import threading
import time
from six.moves import queue
class FairSemaphore(object):
"""Semaphore class that notifies in order of request.
We cannot use a normal Semaphore because it doesn't give any ordering,
which could lead to a request starving. Instead, handle them in the
order we receive them.
:param int concurrency:
How many concurrent threads can have the semaphore at once.
:param float rate_delay:
How long to wait between the start of each thread receiving the
semaphore.
"""
def __init__(self, concurrency, rate_delay):
self._lock = threading.Lock()
self._concurrency = concurrency
if concurrency:
self._count = 0
self._queue = queue.Queue()
self._rate_delay = rate_delay
self._rate_last_ts = time.time()
def __enter__(self):
"""Aquire a semaphore."""
# If concurrency is None, everyone is free to immediately execute.
if not self._concurrency:
# NOTE: Rate limiting still applies.This will ultimately impact
# concurrency a bit due to the mutex.
with self._lock:
execution_time = self._advance_timer()
else:
execution_time = self._get_ticket()
return self._wait_for_execution(execution_time)
def _wait_for_execution(self, execution_time):
"""Wait until the pre-calculated time to run."""
wait_time = execution_time - time.time()
if wait_time > 0:
time.sleep(wait_time)
def _get_ticket(self):
ticket = threading.Event()
with self._lock:
if self._count <= self._concurrency:
# We can execute, no need to wait. Take a ticket and
# move on.
self._count += 1
return self._advance_timer()
else:
# We need to wait for a ticket before we can execute.
# Put ourselves in the ticket queue to be woken up
# when available.
self._queue.put(ticket)
ticket.wait()
with self._lock:
return self._advance_timer()
def _advance_timer(self):
"""Calculate the time when it's ok to run a command again.
This runs inside of the mutex, serializing the calculation
of when it's ok to run again and setting _rate_last_ts to that
new time so that the next thread to calculate when it's safe to
run starts from the time that the current thread calculated.
"""
self._rate_last_ts = self._rate_last_ts + self._rate_delay
return self._rate_last_ts
def __exit__(self, exc_type, exc_value, traceback):
"""Release the semaphore."""
# If concurrency is None, everyone is free to immediately execute
if not self._concurrency:
return
with self._lock:
# If waiters, wake up the next item in the queue (note
# we're under the queue lock so the queue won't change
# under us).
if self._queue.qsize() > 0:
ticket = self._queue.get()
ticket.set()
else:
# Nothing else to do, give our ticket back
self._count -= 1

@ -13,6 +13,7 @@
import os
import warnings
from keystoneauth1 import _fair_semaphore
from keystoneauth1 import session
@ -92,6 +93,16 @@ class Adapter(object):
If True, requests returning failing HTTP responses will raise an
exception; if False, the response is returned. This can be
overridden on a per-request basis via the kwarg of the same name.
:param float rate_limit:
A client-side rate limit to impose on requests made through this
adapter in requests per second. For instance, a rate_limit of 2
means to allow no more than 2 requests per second, and a rate_limit
of 0.5 means to allow no more than 1 request every two seconds.
(optional, defaults to None, which means no rate limiting will be
applied).
:param int concurrency:
How many simultaneous http requests this Adapter can be used for.
(optional, defaults to None, which means no limit).
"""
client_name = None
@ -106,7 +117,9 @@ class Adapter(object):
global_request_id=None,
min_version=None, max_version=None,
default_microversion=None, status_code_retries=None,
retriable_status_codes=None, raise_exc=None):
retriable_status_codes=None, raise_exc=None,
rate_limit=None, concurrency=None,
):
if version and (min_version or max_version):
raise TypeError(
"version is mutually exclusive with min_version and"
@ -144,6 +157,15 @@ class Adapter(object):
if client_version:
self.client_version = client_version
rate_delay = 0.0
if rate_limit:
# 1 / rate converts from requests per second to delay
# between requests needed to achieve that rate.
rate_delay = 1.0 / rate_limit
self._rate_semaphore = _fair_semaphore.FairSemaphore(
concurrency, rate_delay)
def _set_endpoint_filter_kwargs(self, kwargs):
if self.service_type:
kwargs.setdefault('service_type', self.service_type)
@ -210,6 +232,8 @@ class Adapter(object):
if self.raise_exc is not None:
kwargs.setdefault('raise_exc', self.raise_exc)
kwargs.setdefault('rate_semaphore', self._rate_semaphore)
return self.session.request(url, method, **kwargs)
def get_token(self, auth=None):

@ -99,6 +99,18 @@ def _sanitize_headers(headers):
return str_dict
class NoOpSemaphore(object):
"""Empty context manager for use as a default semaphore."""
def __enter__(self):
"""Enter the context manager and do nothing."""
pass
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager and do nothing."""
pass
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
@ -285,6 +297,9 @@ class Session(object):
:param bool collect_timing: Whether or not to collect per-method timing
information for each API call. (optional,
defaults to False)
:param rate_semaphore: Semaphore to be used to control concurrency
and rate limiting of requests. (optional,
defaults to no concurrency or rate control)
"""
user_agent = None
@ -298,7 +313,7 @@ class Session(object):
redirect=_DEFAULT_REDIRECT_LIMIT, additional_headers=None,
app_name=None, app_version=None, additional_user_agent=None,
discovery_cache=None, split_loggers=None,
collect_timing=False):
collect_timing=False, rate_semaphore=None):
self.auth = auth
self.session = _construct_session(session)
@ -320,6 +335,7 @@ class Session(object):
self._split_loggers = split_loggers
self._collect_timing = collect_timing
self._api_times = []
self._rate_semaphore = rate_semaphore or NoOpSemaphore()
if timeout is not None:
self.timeout = float(timeout)
@ -561,7 +577,7 @@ class Session(object):
allow=None, client_name=None, client_version=None,
microversion=None, microversion_service_type=None,
status_code_retries=0, retriable_status_codes=None,
**kwargs):
rate_semaphore=None, **kwargs):
"""Send an HTTP request with the specified characteristics.
Wrapper around `requests.Session.request` to handle tasks such as
@ -647,6 +663,9 @@ class Session(object):
should be retried (optional,
defaults to HTTP 503, has no effect
when status_code_retries is 0).
:param rate_semaphore: Semaphore to be used to control concurrency
and rate limiting of requests. (optional,
defaults to no concurrency or rate control)
:param kwargs: any other parameter that can be passed to
:meth:`requests.Session.request` (such as `headers`).
Except:
@ -670,6 +689,7 @@ class Session(object):
logger = logger or utils.get_logger(__name__)
# HTTP 503 - Service Unavailable
retriable_status_codes = retriable_status_codes or [503]
rate_semaphore = rate_semaphore or self._rate_semaphore
headers = kwargs.setdefault('headers', dict())
if microversion:
@ -797,7 +817,8 @@ class Session(object):
send = functools.partial(self._send_request,
url, method, redirect, log, logger,
split_loggers, connect_retries,
status_code_retries, retriable_status_codes)
status_code_retries, retriable_status_codes,
rate_semaphore)
try:
connection_params = self.get_auth_connection_params(auth=auth)
@ -885,8 +906,9 @@ class Session(object):
def _send_request(self, url, method, redirect, log, logger, split_loggers,
connect_retries, status_code_retries,
retriable_status_codes, connect_retry_delay=0.5,
status_code_retry_delay=0.5, **kwargs):
retriable_status_codes, rate_semaphore,
connect_retry_delay=0.5, status_code_retry_delay=0.5,
**kwargs):
# NOTE(jamielennox): We handle redirection manually because the
# requests lib follows some browser patterns where it will redirect
# POSTs as GETs for certain statuses which is not want we want for an
@ -900,7 +922,8 @@ class Session(object):
try:
try:
resp = self.session.request(method, url, **kwargs)
with rate_semaphore:
resp = self.session.request(method, url, **kwargs)
except requests.exceptions.SSLError as e:
msg = 'SSL exception connecting to %(url)s: %(error)s' % {
'url': url, 'error': e}
@ -934,6 +957,7 @@ class Session(object):
url, method, redirect, log, logger, split_loggers,
status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore,
connect_retries=connect_retries - 1,
connect_retry_delay=connect_retry_delay * 2,
**kwargs)
@ -964,6 +988,7 @@ class Session(object):
# This request actually worked so we can reset the delay count.
new_resp = self._send_request(
location, method, redirect, log, logger, split_loggers,
rate_semaphore=rate_semaphore,
connect_retries=connect_retries,
status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes,
@ -989,6 +1014,7 @@ class Session(object):
connect_retries=connect_retries,
status_code_retries=status_code_retries - 1,
retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore,
status_code_retry_delay=status_code_retry_delay * 2,
**kwargs)

@ -0,0 +1,86 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from threading import Thread
from timeit import default_timer as timer
import mock
from six.moves import queue
import testtools
from keystoneauth1 import _fair_semaphore
class SemaphoreTests(testtools.TestCase):
def _thread_worker(self):
while True:
# get returns the Item, but we don't care about the value so we
# purposely don't assign it to anything.
self.q.get()
with self.s:
self.mock_payload.do_something()
self.q.task_done()
# Have 5 threads do 10 different "things" coordinated by the fair
# semaphore.
def _concurrency_core(self, concurrency, delay):
self.s = _fair_semaphore.FairSemaphore(concurrency, delay)
self.q = queue.Queue()
for i in range(5):
t = Thread(target=self._thread_worker)
t.daemon = True
t.start()
for item in range(0, 10):
self.q.put(item)
self.q.join()
def setUp(self):
super(SemaphoreTests, self).setUp()
self.mock_payload = mock.Mock()
# We should be waiting at least 0.1s between operations, so
# the 10 operations must take at *least* 1 second
def test_semaphore_no_concurrency(self):
start = timer()
self._concurrency_core(None, 0.1)
end = timer()
self.assertTrue((end - start) > 1.0)
self.assertEqual(self.mock_payload.do_something.call_count, 10)
def test_semaphore_single_concurrency(self):
start = timer()
self._concurrency_core(1, 0.1)
end = timer()
self.assertTrue((end - start) > 1.0)
self.assertEqual(self.mock_payload.do_something.call_count, 10)
def test_semaphore_multiple_concurrency(self):
start = timer()
self._concurrency_core(5, 0.1)
end = timer()
self.assertTrue((end - start) > 1.0)
self.assertEqual(self.mock_payload.do_something.call_count, 10)
# do some high speed tests; I don't think we can really assert
# much about these other than they don't deadlock...
def test_semaphore_fast_no_concurrency(self):
self._concurrency_core(None, 0.0)
def test_semaphore_fast_single_concurrency(self):
self._concurrency_core(1, 0.0)
def test_semaphore_fast_multiple_concurrency(self):
self._concurrency_core(5, 0.0)

@ -1565,6 +1565,7 @@ class AdapterTest(utils.TestCase):
with mock.patch.object(sess, 'request') as m:
adapter.Adapter(sess, **adap_kwargs).get(url, **get_kwargs)
m.assert_called_once_with(url, 'GET', endpoint_filter={},
rate_semaphore=mock.ANY,
**exp_kwargs)
# No default_microversion in Adapter, no microversion in get()
@ -1588,6 +1589,7 @@ class AdapterTest(utils.TestCase):
with mock.patch.object(sess, 'request') as m:
adapter.Adapter(sess, **adap_kwargs).get(url, **get_kwargs)
m.assert_called_once_with(url, 'GET', endpoint_filter={},
rate_semaphore=mock.ANY,
**exp_kwargs)
# No raise_exc in Adapter or get()

@ -0,0 +1,10 @@
---
features:
- |
Support added for client-side rate limiting. Two new parameters now
exist for ``keystoneauth1.adapter.Adapter``. ``rate`` expresses a
maximum rate at which to execute requests. ``parallel_limit`` allows
for the creation of a semaphore to control the maximum number of
requests that can be active at any one given point in time.
Both default to ``None`` which has the normal behavior or not limiting
requests in any manner.
Loading…
Cancel
Save