Add typing

Change-Id: Ib00059b676a8fce50a8390b13b52ff5aa5805739
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane
2026-03-08 16:12:08 +00:00
parent bd377bc3c2
commit fcf603253a
34 changed files with 706 additions and 416 deletions
+40 -15
View File
@@ -19,11 +19,28 @@ import hmac
import json
import os
import uuid
from collections.abc import Generator, Sequence
from typing import Any, TypeVar, overload
from oslo_utils import uuidutils
_C = TypeVar("_C")
_T = TypeVar("_T")
def split(text, strip=True):
_AlreadySplit = TypeVar('_AlreadySplit', bound=list[Any] | tuple[Any, ...])
@overload
def split(text: str, strip: bool = True) -> list[str]: ...
@overload
def split(text: _AlreadySplit, strip: bool = True) -> _AlreadySplit: ...
def split(
text: str | _AlreadySplit, strip: bool = True
) -> list[str] | _AlreadySplit:
"""Splits a comma separated text blob into its components.
Does nothing if already a list or tuple.
@@ -38,7 +55,7 @@ def split(text, strip=True):
return text.split(",")
def binary_encode(text, encoding="utf-8"):
def binary_encode(text: bytes | str, encoding: str = "utf-8") -> bytes:
"""Converts a string of into a binary type using given encoding.
Does nothing if text not unicode string.
@@ -51,7 +68,7 @@ def binary_encode(text, encoding="utf-8"):
raise TypeError("Expected binary or string type")
def binary_decode(data, encoding="utf-8"):
def binary_decode(data: bytes | str, encoding: str = "utf-8") -> str:
"""Converts a binary type into a text type using given encoding.
Does nothing if data is already unicode string.
@@ -64,14 +81,16 @@ def binary_decode(data, encoding="utf-8"):
raise TypeError("Expected binary or string type")
def generate_hmac(data, hmac_key):
def generate_hmac(data: bytes | str, hmac_key: bytes | str) -> str:
"""Generate a hmac using a known key given the provided content."""
h = hmac.new(binary_encode(hmac_key), digestmod=hashlib.sha1)
h.update(binary_encode(data))
return h.hexdigest()
def signed_pack(data, hmac_key):
def signed_pack(
data: dict[str, str], hmac_key: str | None
) -> tuple[bytes, str | None]:
"""Pack and sign data with hmac_key."""
raw_data = base64.urlsafe_b64encode(binary_encode(json.dumps(data)))
@@ -82,7 +101,11 @@ def signed_pack(data, hmac_key):
return raw_data, generate_hmac(raw_data, hmac_key) if hmac_key else None
def signed_unpack(data, hmac_data, hmac_keys):
def signed_unpack(
data: str | bytes | None,
hmac_data: str | None,
hmac_keys: Sequence[str] | None,
) -> dict[str, Any] | None:
"""Unpack data and check that it was signed with hmac_key.
:param data: json string that was singed_packed.
@@ -95,7 +118,7 @@ def signed_unpack(data, hmac_data, hmac_keys):
"""
# NOTE(boris-42): For security reason, if there is no hmac_data or
# hmac_keys we don't trust data => return None.
if not (hmac_keys and hmac_data):
if not hmac_keys or not hmac_data or not data:
return None
hmac_data = hmac_data.strip()
if not hmac_data:
@@ -108,7 +131,7 @@ def signed_unpack(data, hmac_data, hmac_keys):
else:
if hmac.compare_digest(hmac_data, user_hmac_data):
try:
contents = json.loads(
contents: dict[str, Any] = json.loads(
binary_decode(base64.urlsafe_b64decode(data))
)
contents["hmac_key"] = hmac_key
@@ -118,14 +141,16 @@ def signed_unpack(data, hmac_data, hmac_keys):
return None
def itersubclasses(cls, _seen=None):
def itersubclasses(
cls: type[_C], _seen: set[type] | None = None
) -> Generator[type[_C], None, None]:
"""Generator over all subclasses of a given class in depth first order."""
_seen = _seen or set()
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
subs = cls.__subclasses__(cls) # type: ignore[call-arg]
for sub in subs:
if sub not in _seen:
_seen.add(sub)
@@ -134,13 +159,13 @@ def itersubclasses(cls, _seen=None):
yield sub
def import_modules_from_package(package):
def import_modules_from_package(package: str) -> None:
"""Import modules from package and append into sys.modules
:param: package - Full package name. For example: rally.deploy.engines
"""
path = [os.path.dirname(__file__), ".."] + package.split(".")
path = os.path.join(*path)
path_parts = [os.path.dirname(__file__), ".."] + package.split(".")
path = os.path.join(*path_parts)
for root, dirs, files in os.walk(path):
for filename in files:
if filename.startswith("__") or not filename.endswith(".py"):
@@ -150,7 +175,7 @@ def import_modules_from_package(package):
__import__(module_name)
def shorten_id(span_id):
def shorten_id(span_id: str | int) -> int:
"""Convert from uuid4 to 64 bit id for OpenTracing"""
int64_max = (1 << 64) - 1
if isinstance(span_id, int):
@@ -163,7 +188,7 @@ def shorten_id(span_id):
return short_id
def uuid_to_int128(span_uuid):
def uuid_to_int128(span_uuid: str | int) -> int:
"""Convert from uuid4 to 128 bit id for OpenTracing"""
if isinstance(span_uuid, int):
return span_uuid
+11 -7
View File
@@ -13,10 +13,12 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Callable
import os
from typing import Any
def env(*args, **kwargs):
def env(*args: str, **kwargs: str) -> str:
"""Returns the first environment variable set.
If all are empty, defaults to '' or keyword arg `default`.
@@ -28,7 +30,9 @@ def env(*args, **kwargs):
return kwargs.get("default", "")
def arg(*args, **kwargs):
def arg(
*args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator for CLI args.
Example:
@@ -38,22 +42,22 @@ def arg(*args, **kwargs):
... pass
"""
def _decorator(func):
def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
add_arg(func, *args, **kwargs)
return func
return _decorator
def add_arg(func, *args, **kwargs):
def add_arg(func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Bind CLI arguments to a shell.py `do_foo` function."""
if not hasattr(func, "arguments"):
func.arguments = []
setattr(func, "arguments", [])
# NOTE(sirp): avoid dups that can occur when the module is shared across
# tests.
if (args, kwargs) not in func.arguments:
if (args, kwargs) not in getattr(func, "arguments"):
# Because of the semantics of decorator composition if we just append
# to the options list positional options will appear to be backwards.
func.arguments.insert(0, (args, kwargs))
getattr(func, "arguments").insert(0, (args, kwargs))
+13 -11
View File
@@ -13,8 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
import argparse
import json
import os
from typing import Any
from oslo_utils import encodeutils
from oslo_utils import uuidutils
@@ -26,7 +28,7 @@ from osprofiler import exc
class BaseCommand:
group_name = None
group_name: str | None = None
class TraceCommands(BaseCommand):
@@ -84,7 +86,7 @@ class TraceCommands(BaseCommand):
help="filename for rendering the dot graph in pdf format",
)
@cliutils.arg("--out", dest="file_name", help="save output in file")
def show(self, args):
def show(self, args: argparse.Namespace) -> None:
"""Display trace results in HTML, JSON or DOT format."""
if not args.conn_str:
@@ -102,7 +104,7 @@ class TraceCommands(BaseCommand):
try:
engine = base.get_driver(args.conn_str, **args.__dict__)
except Exception as e:
raise exc.CommandError(e.message)
raise exc.CommandError(str(e))
trace = engine.get_report(args.trace)
@@ -115,7 +117,7 @@ class TraceCommands(BaseCommand):
# Since datetime.datetime is not JSON serializable by default,
# this method will handle that.
def datetime_json_serialize(obj):
def datetime_json_serialize(obj: Any) -> Any:
if hasattr(obj, "isoformat"):
return obj.isoformat()
else:
@@ -162,9 +164,9 @@ class TraceCommands(BaseCommand):
else:
print(output)
def _create_dot_graph(self, trace):
def _create_dot_graph(self, trace: dict[str, Any]) -> Any:
try:
import graphviz
import graphviz # type: ignore[import-not-found]
except ImportError:
raise exc.CommandError(
"graphviz library is required to use this option."
@@ -173,7 +175,7 @@ class TraceCommands(BaseCommand):
dot = graphviz.Digraph(format="pdf")
next_id = [0]
def _create_node(info):
def _create_node(info: dict[str, Any]) -> str:
time_taken = info["finished"] - info["started"]
service = info["service"] + ":" if "service" in info else ""
name = info["name"]
@@ -194,7 +196,7 @@ class TraceCommands(BaseCommand):
dot.node(node_id, label)
return node_id
def _create_sub_graph(root):
def _create_sub_graph(root: dict[str, Any]) -> str:
rid = _create_node(root["info"])
for child in root["children"]:
cid = _create_sub_graph(child)
@@ -218,7 +220,7 @@ class TraceCommands(BaseCommand):
default=False,
help="List all traces that contain error.",
)
def list(self, args):
def list(self, args: argparse.Namespace) -> None:
"""List all traces"""
if not args.conn_str:
raise exc.CommandError(
@@ -229,13 +231,13 @@ class TraceCommands(BaseCommand):
try:
engine = base.get_driver(args.conn_str, **args.__dict__)
except Exception as e:
raise exc.CommandError(e.message)
raise exc.CommandError(str(e))
fields = ("base_id", "timestamp")
pretty_table = prettytable.PrettyTable(fields)
pretty_table.align = "l"
if not args.error_trace:
traces = engine.list_traces(fields)
traces = engine.list_traces(set(fields))
else:
traces = engine.list_error_traces()
for trace in traces:
+11 -5
View File
@@ -21,6 +21,7 @@ Command-line interface to the OpenStack Profiler.
import argparse
import inspect
import sys
from typing import Any
from oslo_config import cfg
@@ -31,13 +32,13 @@ from osprofiler import opts
class OSProfilerShell:
def __init__(self, argv):
def __init__(self, argv: list[str]) -> None:
args = self._get_base_parser().parse_args(argv)
opts.set_defaults(cfg.CONF)
args.func(args)
def _get_base_parser(self):
def _get_base_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="osprofiler", description=__doc__.strip(), add_help=True
)
@@ -50,9 +51,13 @@ class OSProfilerShell:
return parser
def _append_subcommands(self, parent_parser):
def _append_subcommands(
self, parent_parser: argparse.ArgumentParser
) -> None:
subcommands = parent_parser.add_subparsers(help="<subcommands>")
for group_cls in commands.BaseCommand.__subclasses__():
if group_cls.group_name is None:
continue
group_parser = subcommands.add_parser(group_cls.group_name)
subcommand_parser = group_parser.add_subparsers()
@@ -71,7 +76,7 @@ class OSProfilerShell:
command_parser.add_argument(*args, **kwargs)
command_parser.set_defaults(func=callback)
def _no_project_and_domain_set(self, args):
def _no_project_and_domain_set(self, args: Any) -> bool:
if not (
args.os_project_id
or (
@@ -85,7 +90,7 @@ class OSProfilerShell:
return False
def main(args=None):
def main(args: list[str] | None = None) -> int | None:
if args is None:
args = sys.argv[1:]
@@ -94,6 +99,7 @@ def main(args=None):
except exc.CommandError as e:
print(e.message)
return 1
return None
if __name__ == "__main__":
+46 -40
View File
@@ -15,6 +15,7 @@
import datetime
import logging
from typing import Any
from urllib import parse as urlparse
from osprofiler import _utils
@@ -22,7 +23,7 @@ from osprofiler import _utils
LOG = logging.getLogger(__name__)
def get_driver(connection_string, *args, **kwargs):
def get_driver(connection_string: str, *args: Any, **kwargs: Any) -> "Driver":
"""Create driver's instance according to specified connection string"""
# NOTE(ayelistratov) Backward compatibility with old Messaging notation
# Remove after patching all OS services
@@ -66,20 +67,25 @@ class Driver:
and implemented by any class derived from this class.
"""
default_trace_fields = {"base_id", "timestamp"}
default_trace_fields: set[str] = {"base_id", "timestamp"}
def __init__(
self, connection_str, project=None, service=None, host=None, **kwargs
):
self,
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
self.connection_str = connection_str
self.project = project
self.service = service
self.host = host
self.result = {}
self.started_at = None
self.finished_at = None
self.result: dict[str, Any] = {}
self.started_at: datetime.datetime | None = None
self.finished_at: datetime.datetime | None = None
# Last trace started time
self.last_started_at = None
self.last_started_at: datetime.datetime | None = None
profiler_config = kwargs.get("conf", {}).get("profiler", {})
if hasattr(profiler_config, "filter_error_trace"):
@@ -87,7 +93,7 @@ class Driver:
else:
self.filter_error_trace = False
def notify(self, info, **kwargs):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
"""This method will be called on each notifier.notify() call.
To add new drivers you should, create new subclass of this class and
@@ -108,7 +114,7 @@ class Driver:
"or has to be overridden"
)
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
"""Forms and returns report composed from the stored notifications.
:param base_id: Base id of trace elements.
@@ -119,11 +125,13 @@ class Driver:
)
@classmethod
def get_name(cls):
def get_name(cls) -> str:
"""Returns backend specific name for the driver."""
return cls.__name__
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
"""Query all traces from the storage.
:param fields: Set of trace fields to return. Defaults to 'base_id'
@@ -136,7 +144,7 @@ class Driver:
"or has to be overridden"
)
def list_error_traces(self):
def list_error_traces(self) -> list[dict[str, Any]]:
"""Query all error traces from the storage.
:return List of traces, where each trace is a dictionary containing
@@ -148,7 +156,7 @@ class Driver:
)
@staticmethod
def _build_tree(nodes):
def _build_tree(nodes: dict[str, Any]) -> list[dict[str, Any]]:
"""Builds the tree (forest) data structure based on the list of nodes.
Tree building works in O(n*log(n)).
@@ -161,7 +169,7 @@ class Driver:
empty for leafs)
"""
tree = []
tree: list[dict[str, Any]] = []
for trace_id in nodes:
node = nodes[trace_id]
@@ -182,15 +190,15 @@ class Driver:
def _append_results(
self,
trace_id,
parent_id,
name,
project,
service,
host,
timestamp,
raw_payload=None,
):
trace_id: str,
parent_id: str,
name: str,
project: str | None,
service: str | None,
host: str | None,
timestamp: str,
raw_payload: dict[str, Any] | None = None,
) -> None:
"""Appends the notification to the dictionary of notifications.
:param trace_id: UUID of current trace point
@@ -204,9 +212,7 @@ class Driver:
:param raw_payload: raw notification without any filtering, with all
fields included
"""
timestamp = datetime.datetime.strptime(
timestamp, "%Y-%m-%dT%H:%M:%S.%f"
)
ts = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
if trace_id not in self.result:
self.result[trace_id] = {
"info": {
@@ -222,29 +228,29 @@ class Driver:
self.result[trace_id]["info"][f"meta.raw_payload.{name}"] = raw_payload
if name.endswith("stop"):
self.result[trace_id]["info"]["finished"] = timestamp
self.result[trace_id]["info"]["finished"] = ts
self.result[trace_id]["info"]["exception"] = "None"
if raw_payload and "info" in raw_payload:
exc = raw_payload["info"].get("etype", "None")
self.result[trace_id]["info"]["exception"] = exc
else:
self.result[trace_id]["info"]["started"] = timestamp
if not self.last_started_at or self.last_started_at < timestamp:
self.last_started_at = timestamp
self.result[trace_id]["info"]["started"] = ts
if not self.last_started_at or self.last_started_at < ts:
self.last_started_at = ts
if not self.started_at or self.started_at > timestamp:
self.started_at = timestamp
if not self.started_at or self.started_at > ts:
self.started_at = ts
if not self.finished_at or self.finished_at < timestamp:
self.finished_at = timestamp
if not self.finished_at or self.finished_at < ts:
self.finished_at = ts
def _parse_results(self):
def _parse_results(self) -> dict[str, Any]:
"""Parses Driver's notifications placed by _append_results() .
:returns: full profiling report
"""
def msec(dt):
def msec(dt: datetime.timedelta) -> int:
# NOTE(boris-42): Unfortunately this is the simplest way that works
# in py26 and py27
microsec = (
@@ -252,7 +258,7 @@ class Driver:
)
return int(microsec / 1000.0)
stats = {}
stats: dict[str, Any] = {}
for r in self.result.values():
# NOTE(boris-42): We are not able to guarantee that the backend
@@ -282,12 +288,12 @@ class Driver:
"name": "total",
"started": 0,
"finished": msec(self.finished_at - self.started_at)
if self.started_at
if self.started_at and self.finished_at
else None,
"last_trace_started": msec(
self.last_started_at - self.started_at
)
if self.started_at
if self.started_at and self.last_started_at
else None,
},
"children": self._build_tree(self.result),
+25 -22
View File
@@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
from urllib import parse as parser
from oslo_config import cfg
@@ -24,14 +25,14 @@ from osprofiler import exc
class ElasticsearchDriver(base.Driver):
def __init__(
self,
connection_str,
index_name="osprofiler-notifications",
project=None,
service=None,
host=None,
conf=cfg.CONF,
**kwargs,
):
connection_str: str,
index_name: str = "osprofiler-notifications",
project: str | None = None,
service: str | None = None,
host: str | None = None,
conf: cfg.ConfigOpts = cfg.CONF,
**kwargs: Any,
) -> None:
"""Elasticsearch driver for OSProfiler."""
super().__init__(
@@ -60,10 +61,10 @@ class ElasticsearchDriver(base.Driver):
self.index_name_error = "osprofiler-notifications-error"
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "elasticsearch"
def notify(self, info):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
"""Send notifications to Elasticsearch.
:param info: Contains information about trace element.
@@ -80,7 +81,7 @@ class ElasticsearchDriver(base.Driver):
info = info.copy()
info["project"] = self.project
info["service"] = self.service
self.client.index(
self.client.index( # type: ignore[call-arg]
index=self.index_name,
doc_type=self.conf.profiler.es_doc_type,
body=info,
@@ -92,22 +93,22 @@ class ElasticsearchDriver(base.Driver):
):
self.notify_error_trace(info)
def notify_error_trace(self, info):
def notify_error_trace(self, info: dict[str, Any]) -> None:
"""Store base_id and timestamp of error trace to a separate index."""
self.client.index(
self.client.index( # type: ignore[call-arg]
index=self.index_name_error,
doc_type=self.conf.profiler.es_doc_type,
body={"base_id": info["base_id"], "timestamp": info["timestamp"]},
)
def _hits(self, response):
def _hits(self, response: Any) -> list[Any]:
"""Returns all hits of search query using scrolling
:param response: ElasticSearch query response
"""
scroll_id = response["_scroll_id"]
scroll_size = len(response["hits"]["hits"])
result = []
result: list[Any] = []
while scroll_size > 0:
for hit in response["hits"]["hits"]:
@@ -120,7 +121,9 @@ class ElasticsearchDriver(base.Driver):
return result
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
"""Query all traces from the storage.
:param fields: Set of trace fields to return. Defaults to 'base_id'
@@ -128,10 +131,10 @@ class ElasticsearchDriver(base.Driver):
:returns: List of traces, where each trace is a dictionary containing
at least `base_id` and `timestamp`.
"""
query = {"match_all": {}}
query: dict[str, Any] = {"match_all": {}}
fields = set(fields or self.default_trace_fields)
response = self.client.search(
response = self.client.search( # type: ignore[call-arg]
index=self.index_name,
doc_type=self.conf.profiler.es_doc_type,
size=self.conf.profiler.es_scroll_size,
@@ -145,9 +148,9 @@ class ElasticsearchDriver(base.Driver):
return self._hits(response)
def list_error_traces(self):
def list_error_traces(self) -> list[dict[str, Any]]:
"""Returns all traces that have error/exception."""
response = self.client.search(
response = self.client.search( # type: ignore[call-arg]
index=self.index_name_error,
doc_type=self.conf.profiler.es_doc_type,
size=self.conf.profiler.es_scroll_size,
@@ -161,12 +164,12 @@ class ElasticsearchDriver(base.Driver):
return self._hits(response)
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
"""Retrieves and parses notification from Elasticsearch.
:param base_id: Base id of trace elements.
"""
response = self.client.search(
response = self.client.search( # type: ignore[call-arg]
index=self.index_name,
doc_type=self.conf.profiler.es_doc_type,
size=self.conf.profiler.es_scroll_size,
+12 -8
View File
@@ -13,6 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
from oslo_config import cfg
from osprofiler.drivers import base
from osprofiler import exc
@@ -21,17 +25,17 @@ from osprofiler import exc
class Jaeger(base.Driver):
def __init__(
self,
connection_str,
project=None,
service=None,
host=None,
conf=None,
**kwargs,
):
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
conf: cfg.ConfigOpts | None = None,
**kwargs: Any,
) -> None:
"""Jaeger driver for OSProfiler."""
raise exc.CommandError('Jaeger driver is no longer supported')
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "jaeger"
+43 -30
View File
@@ -19,6 +19,7 @@ Classes to use VMware vRealize Log Insight as the trace data store.
import json
import logging as log
from typing import Any
from urllib import parse as urlparse
import netaddr
@@ -48,8 +49,13 @@ class LogInsightDriver(base.Driver):
"""
def __init__(
self, connection_str, project=None, service=None, host=None, **kwargs
):
self,
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(
connection_str, project=project, service=service, host=host
)
@@ -73,19 +79,19 @@ class LogInsightDriver(base.Driver):
self._client.login()
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "loginsight"
def notify(self, info):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
"""Send trace to Log Insight server."""
trace = info.copy()
trace["project"] = self.project
trace["service"] = self.service
event = {"text": "OSProfiler trace"}
event: dict[str, Any] = {"text": "OSProfiler trace"}
def _create_field(name, content):
def _create_field(name: str, content: Any) -> dict[str, Any]:
return {"name": name, "content": content}
event["fields"] = [
@@ -99,7 +105,7 @@ class LogInsightDriver(base.Driver):
self._client.send_event(event)
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
"""Retrieves and parses trace data from Log Insight.
:param base_id: Trace base ID
@@ -150,13 +156,13 @@ class LogInsightClient:
def __init__(
self,
host,
username,
password,
api_port=9000,
api_ssl_port=9543,
query_timeout=60000,
):
host: str,
username: str,
password: str,
api_port: int = 9000,
api_ssl_port: int = 9543,
query_timeout: int = 60000,
) -> None:
self._host = host
self._username = username
self._password = password
@@ -164,9 +170,9 @@ class LogInsightClient:
self._api_ssl_port = api_ssl_port
self._query_timeout = query_timeout
self._session = requests.Session()
self._session_id = None
self._session_id: str | None = None
def _build_base_url(self, scheme):
def _build_base_url(self, scheme: str) -> str:
proto_str = f"{scheme}://"
host_str = (
f"[{self._host}]" if netaddr.valid_ipv6(self._host) else self._host
@@ -176,7 +182,7 @@ class LogInsightClient:
)
return proto_str + host_str + port_str
def _check_response(self, resp):
def _check_response(self, resp: requests.Response) -> None:
if resp.status_code == 440:
raise exc.LogInsightLoginTimeout()
@@ -189,12 +195,18 @@ class LogInsightClient:
except ValueError:
pass
else:
msg = resp.reason
msg = resp.reason or msg
raise exc.LogInsightAPIError(msg)
def _send_request(
self, method, scheme, path, headers=None, body=None, params=None
):
self,
method: str,
scheme: str,
path: str,
headers: dict[str, str] | None = None,
body: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> Any:
url = f"{self._build_base_url(scheme)}/{path}"
headers = headers or {}
@@ -205,20 +217,21 @@ class LogInsightClient:
req = requests.Request(
method, url, headers=headers, data=json.dumps(body), params=params
)
req = req.prepare()
resp = self._session.send(req, verify=False)
prepped = req.prepare()
resp = self._session.send(prepped, verify=False)
self._check_response(resp)
return resp.json()
def _get_auth_header(self):
return {"X-LI-Session-Id": self._session_id}
def _get_auth_header(self) -> dict[str, str]:
return {"X-LI-Session-Id": self._session_id or ""}
def _trunc_session_id(self):
def _trunc_session_id(self) -> str | None:
if self._session_id:
return self._session_id[-5:]
return None
def _is_current_session_active(self):
def _is_current_session_active(self) -> bool:
try:
self._send_request(
"get",
@@ -237,7 +250,7 @@ class LogInsightClient:
return False
@synchronized("li_login_lock")
def login(self):
def login(self) -> None:
# Another thread might have created the session while the current
# thread was waiting for the lock.
if self._session_id and self._is_current_session_active():
@@ -254,13 +267,13 @@ class LogInsightClient:
self._session_id = resp["sessionId"]
LOG.debug("Established session %s.", self._trunc_session_id())
def send_event(self, event):
def send_event(self, event: dict[str, Any]) -> None:
events = {"events": [event]}
self._send_request(
"post", "http", self.EVENTS_INGEST_PATH, body=events
)
def query_events(self, params):
def query_events(self, params: dict[str, str]) -> Any:
# Assumes that the keys and values in the params are strings and
# the operator is "CONTAINS".
constraints = []
@@ -272,7 +285,7 @@ class LogInsightClient:
self.QUERY_EVENTS_BASE_PATH, "/".join(constraints)
)
def _query_events():
def _query_events() -> Any:
return self._send_request(
"get",
"https",
+32 -21
View File
@@ -16,7 +16,9 @@
import functools
import signal
import time
from typing import Any
from oslo_config import cfg
from oslo_utils import importutils
from osprofiler.drivers import base
@@ -25,16 +27,16 @@ from osprofiler.drivers import base
class Messaging(base.Driver):
def __init__(
self,
connection_str,
project=None,
service=None,
host=None,
context=None,
conf=None,
transport_url=None,
idle_timeout=1,
**kwargs,
):
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
context: Any = None,
conf: cfg.ConfigOpts | None = None,
transport_url: str | None = None,
idle_timeout: int = 1,
**kwargs: Any,
) -> None:
"""Driver that uses messaging as transport for notifications
:param connection_str: OSProfiler driver connection string,
@@ -73,7 +75,7 @@ class Messaging(base.Driver):
)
conf = oslo_config.cfg.CONF
transport_kwargs = {}
transport_kwargs: dict[str, Any] = {}
if transport_url:
transport_kwargs["url"] = transport_url
@@ -91,10 +93,12 @@ class Messaging(base.Driver):
self.idle_timeout = idle_timeout
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "messaging"
def notify(self, info, context=None):
def notify(
self, info: dict[str, Any], context: Any = None, **kwargs: Any
) -> None:
"""Send notifications to backend via oslo.messaging notifier API.
:param info: Contains information about trace element.
@@ -118,7 +122,7 @@ class Messaging(base.Driver):
info,
)
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
notification_endpoint = NotifyEndpoint(self.oslo_messaging, base_id)
endpoints = [notification_endpoint]
targets = [self.oslo_messaging.Target(topic="profiler")]
@@ -126,7 +130,7 @@ class Messaging(base.Driver):
self.transport, targets, endpoints, executor="threading"
)
state = dict(running=False)
state: dict[str, bool] = dict(running=False)
sfn = functools.partial(signal_handler, state=state)
# modify signal handlers to handle interruption gracefully
@@ -194,21 +198,28 @@ class Messaging(base.Driver):
class NotifyEndpoint:
def __init__(self, oslo_messaging, base_id):
self.received_messages = []
def __init__(self, oslo_messaging: Any, base_id: str) -> None:
self.received_messages: list[Any] = []
self.last_read_time = time.time()
self.filter_rule = oslo_messaging.NotificationFilter(
payload={"base_id": base_id}
)
def info(self, ctxt, publisher_id, event_type, payload, metadata):
def info(
self,
ctxt: Any,
publisher_id: Any,
event_type: Any,
payload: Any,
metadata: Any,
) -> None:
self.received_messages.append(payload)
self.last_read_time = time.time()
def get_messages(self):
def get_messages(self) -> list[Any]:
return self.received_messages
def get_last_read_time(self):
def get_last_read_time(self) -> float:
return self.last_read_time # time when the latest event was received
@@ -216,6 +227,6 @@ class SignalExit(BaseException):
pass
def signal_handler(signum, frame, state):
def signal_handler(signum: Any, frame: Any, state: dict[str, bool]) -> None:
state["running"] = False
raise SignalExit()
+20 -16
View File
@@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
from osprofiler.drivers import base
from osprofiler import exc
@@ -20,13 +22,13 @@ from osprofiler import exc
class MongoDB(base.Driver):
def __init__(
self,
connection_str,
db_name="osprofiler",
project=None,
service=None,
host=None,
**kwargs,
):
connection_str: str,
db_name: str = "osprofiler",
project: str | None = None,
service: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
"""MongoDB driver for OSProfiler."""
super().__init__(
@@ -45,14 +47,14 @@ class MongoDB(base.Driver):
"To install with pip:\n `pip install pymongo`."
)
client = MongoClient(self.connection_str, connect=False)
client: Any = MongoClient(self.connection_str, connect=False)
self.db = client[db_name]
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "mongodb"
def notify(self, info):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
"""Send notifications to MongoDB.
:param info: Contains information about trace element.
@@ -76,7 +78,7 @@ class MongoDB(base.Driver):
):
self.notify_error_trace(data)
def notify_error_trace(self, data):
def notify_error_trace(self, data: dict[str, Any]) -> None:
"""Store base_id and timestamp of error trace to a separate db."""
self.db.profiler_error.update(
{"base_id": data["base_id"]},
@@ -84,7 +86,9 @@ class MongoDB(base.Driver):
upsert=True,
)
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
"""Query all traces from the storage.
:param fields: Set of trace fields to return. Defaults to 'base_id'
@@ -94,7 +98,7 @@ class MongoDB(base.Driver):
"""
fields = set(fields or self.default_trace_fields)
ids = self.db.profiler.find({}).distinct("base_id")
out_format = {"base_id": 1, "timestamp": 1, "_id": 0}
out_format: dict[str, int] = {"base_id": 1, "timestamp": 1, "_id": 0}
out_format.update({i: 1 for i in fields})
return [
self.db.profiler.find({"base_id": i}, out_format).sort(
@@ -103,12 +107,12 @@ class MongoDB(base.Driver):
for i in ids
]
def list_error_traces(self):
def list_error_traces(self) -> list[dict[str, Any]]:
"""Returns all traces that have error/exception."""
out_format = {"base_id": 1, "timestamp": 1, "_id": 0}
return self.db.profiler_error.find({}, out_format)
return list(self.db.profiler_error.find({}, out_format))
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
"""Retrieves and parses notification from MongoDB.
:param base_id: Base id of trace elements.
+31 -25
View File
@@ -13,6 +13,7 @@
# under the License.
import collections
from typing import Any
from urllib import parse as parser
from oslo_config import cfg
@@ -26,13 +27,13 @@ from osprofiler import exc
class OTLP(base.Driver):
def __init__(
self,
connection_str,
project=None,
service=None,
host=None,
conf=cfg.CONF,
**kwargs,
):
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
conf: cfg.ConfigOpts = cfg.CONF,
**kwargs: Any,
) -> None:
"""OTLP driver using OTLP exporters."""
super().__init__(
@@ -74,30 +75,32 @@ class OTLP(base.Driver):
self.tracer = self.trace_api.get_tracer(__name__)
exporter = OTLPSpanExporter(f"{parsed_url.geturl()}/v1/traces")
self.trace_api.get_tracer_provider().add_span_processor(
self.trace_api.get_tracer_provider().add_span_processor( # type: ignore[attr-defined]
BatchSpanProcessor(exporter)
)
self.spans = collections.deque()
self.spans: collections.deque[Any] = collections.deque()
def _get_service_name(self, conf, project, service):
def _get_service_name(
self, conf: cfg.ConfigOpts, project: str | None, service: str | None
) -> str:
prefix = conf.profiler_otlp.service_name_prefix
if prefix:
return f"{prefix}-{project}-{service}"
return f"{project}-{service}"
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "otlp"
def _kind(self, name):
def _kind(self, name: str) -> Any:
if "wsgi" in name:
return self.trace_api.SpanKind.SERVER
elif "db" in name or "http" in name or "api" in name:
return self.trace_api.SpanKind.CLIENT
return self.trace_api.SpanKind.INTERNAL
def _name(self, payload):
def _name(self, payload: dict[str, Any]) -> str:
info = payload["info"]
if info.get("request"):
return "WSGI_{}_{}".format(
@@ -111,9 +114,10 @@ class OTLP(base.Driver):
return "REQUESTS_{}_{}".format(
info["requests"]["method"], info["requests"]["hostname"]
)
return payload["name"].rstrip("-start")
return str(payload["name"]).rstrip("-start")
def notify(self, payload):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
payload = info
if payload["name"].endswith("start"):
parent = self.trace_api.SpanContext(
trace_id=utils.uuid_to_int128(payload["base_id"]),
@@ -136,12 +140,12 @@ class OTLP(base.Driver):
context=ctx,
)
span._context = self.trace_api.SpanContext(
trace_id=span.context.trace_id,
span._context = self.trace_api.SpanContext( # type: ignore[attr-defined]
trace_id=span.context.trace_id, # type: ignore[attr-defined]
span_id=utils.shorten_id(payload["trace_id"]),
is_remote=span.context.is_remote,
trace_flags=span.context.trace_flags,
trace_state=span.context.trace_state,
is_remote=span.context.is_remote, # type: ignore[attr-defined]
trace_flags=span.context.trace_flags, # type: ignore[attr-defined]
trace_state=span.context.trace_state, # type: ignore[attr-defined]
)
self.spans.append(span)
@@ -171,16 +175,18 @@ class OTLP(base.Driver):
)
span.end()
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
return self._parse_results()
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
return []
def list_error_traces(self):
def list_error_traces(self) -> list[dict[str, Any]]:
return []
def create_span_tags(self, payload):
def create_span_tags(self, payload: dict[str, Any]) -> dict[str, Any]:
"""Create tags an OpenTracing compatible span.
:param info: Information from OSProfiler trace.
@@ -188,7 +194,7 @@ class OTLP(base.Driver):
from OpenTracing sematic conventions,
and some other custom tags related to http, db calls.
"""
tags = {}
tags: dict[str, Any] = {}
info = payload["info"]
if info.get("db"):
+59 -39
View File
@@ -14,6 +14,8 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Generator
from typing import Any, cast
from urllib import parse as parser
from debtcollector import removals
@@ -25,7 +27,7 @@ from osprofiler import exc
class Redis(base.Driver):
@removals.removed_kwarg(
@removals.removed_kwarg( # type: ignore[untyped-decorator]
"db",
message="'db' parameter is deprecated "
"and will be removed in future. "
@@ -34,14 +36,14 @@ class Redis(base.Driver):
)
def __init__(
self,
connection_str,
db=0,
project=None,
service=None,
host=None,
conf=cfg.CONF,
**kwargs,
):
connection_str: str,
db: int = 0,
project: str | None = None,
service: str | None = None,
host: str | None = None,
conf: cfg.ConfigOpts = cfg.CONF,
**kwargs: Any,
) -> None:
"""Redis driver for OSProfiler."""
super().__init__(
@@ -69,10 +71,10 @@ class Redis(base.Driver):
self.namespace_error = "osprofiler_error:"
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "redis"
def notify(self, info):
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
"""Send notifications to Redis.
:param info: Contains information about trace element.
@@ -97,7 +99,7 @@ class Redis(base.Driver):
):
self.notify_error_trace(data)
def notify_error_trace(self, data):
def notify_error_trace(self, data: dict[str, Any]) -> None:
"""Store base_id and timestamp of error trace to a separate key."""
key = self.namespace_error + data["base_id"]
value = jsonutils.dumps(
@@ -105,7 +107,9 @@ class Redis(base.Driver):
)
self.db.set(key, value)
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
"""Query all traces from the storage.
:param fields: Set of trace fields to return. Defaults to 'base_id'
@@ -122,7 +126,10 @@ class Redis(base.Driver):
ids = self.db.scan_iter(match=self.namespace_opt + "*")
for i in ids:
# for each trace query the first event to have a timestamp
first_event = jsonutils.loads(self.db.lindex(i, 1))
raw = cast(bytes | None, self.db.lindex(i, 1))
if raw is None:
continue
first_event = jsonutils.loads(raw)
result.append(
{
key: value
@@ -132,15 +139,19 @@ class Redis(base.Driver):
)
return result
def _list_traces_legacy(self, fields):
def _list_traces_legacy(self, fields: set[str]) -> list[dict[str, Any]]:
# With current schema every event is stored under its own unique key
# To query all traces we first need to get all keys, then
# get all events, sort them and pick up only the first one
ids = self.db.scan_iter(match=self.namespace + "*")
traces = [jsonutils.loads(self.db.get(i)) for i in ids]
traces = [
jsonutils.loads(raw)
for i in ids
if (raw := cast(bytes | None, self.db.get(i))) is not None
]
traces.sort(key=lambda x: x["timestamp"])
seen_ids = set()
result = []
seen_ids: set[str] = set()
result: list[dict[str, Any]] = []
for trace in traces:
if trace["base_id"] not in seen_ids:
seen_ids.add(trace["base_id"])
@@ -153,13 +164,17 @@ class Redis(base.Driver):
)
return result
def list_error_traces(self):
def list_error_traces(self) -> list[dict[str, Any]]:
"""Returns all traces that have error/exception."""
ids = self.db.scan_iter(match=self.namespace_error + "*")
traces = [jsonutils.loads(self.db.get(i)) for i in ids]
traces = [
jsonutils.loads(raw)
for i in ids
if (raw := cast(bytes | None, self.db.get(i))) is not None
]
traces.sort(key=lambda x: x["timestamp"])
seen_ids = set()
result = []
seen_ids: set[str] = set()
result: list[dict[str, Any]] = []
for trace in traces:
if trace["base_id"] not in seen_ids:
seen_ids.add(trace["base_id"])
@@ -167,19 +182,24 @@ class Redis(base.Driver):
return result
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
"""Retrieves and parses notification from Redis.
:param base_id: Base id of trace elements.
"""
def iterate_events():
def iterate_events() -> Generator[bytes, None, None]:
for key in self.db.scan_iter(
match=self.namespace + base_id + "*"
): # legacy
yield self.db.get(key)
data = cast(bytes | None, self.db.get(key))
if data is not None:
yield data
yield from self.db.lrange(self.namespace_opt + base_id, 0, -1)
yield from cast(
list[bytes],
self.db.lrange(self.namespace_opt + base_id, 0, -1),
)
for data in iterate_events():
n = jsonutils.loads(data)
@@ -199,7 +219,7 @@ class Redis(base.Driver):
class RedisSentinel(Redis, base.Driver):
@removals.removed_kwarg(
@removals.removed_kwarg( # type: ignore[untyped-decorator]
"db",
message="'db' parameter is deprecated "
"and will be removed in future. "
@@ -208,14 +228,14 @@ class RedisSentinel(Redis, base.Driver):
)
def __init__(
self,
connection_str,
db=0,
project=None,
service=None,
host=None,
conf=cfg.CONF,
**kwargs,
):
connection_str: str,
db: int = 0,
project: str | None = None,
service: str | None = None,
host: str | None = None,
conf: cfg.ConfigOpts = cfg.CONF,
**kwargs: Any,
) -> None:
"""Redis driver for OSProfiler."""
super().__init__(
@@ -238,16 +258,16 @@ class RedisSentinel(Redis, base.Driver):
self.conf = conf
socket_timeout = self.conf.profiler.socket_timeout
parsed_url = parser.urlparse(self.connection_str)
sentinel = Sentinel(
[(parsed_url.hostname, int(parsed_url.port))],
sentinel = Sentinel( # type: ignore[no-untyped-call]
[(parsed_url.hostname, int(parsed_url.port))], # type: ignore[arg-type]
password=parsed_url.password,
socket_timeout=socket_timeout,
)
self.db = sentinel.master_for(
self.db = sentinel.master_for( # type: ignore[no-untyped-call]
self.conf.profiler.sentinel_service_name,
socket_timeout=socket_timeout,
)
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "redissentinel"
+21 -10
View File
@@ -14,6 +14,7 @@
# under the License.
import logging
from typing import Any
from oslo_serialization import jsonutils
@@ -25,14 +26,19 @@ LOG = logging.getLogger(__name__)
class SQLAlchemyDriver(base.Driver):
def __init__(
self, connection_str, project=None, service=None, host=None, **kwargs
):
self,
connection_str: str,
project: str | None = None,
service: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(
connection_str, project=project, service=service, host=host
)
try:
from sqlalchemy import create_engine
from sqlalchemy import create_engine # type: ignore[import-not-found]
from sqlalchemy import Table, MetaData, Column
from sqlalchemy import String, JSON, Integer
except ImportError:
@@ -75,10 +81,12 @@ class SQLAlchemyDriver(base.Driver):
)
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "sqlalchemy"
def notify(self, info, context=None):
def notify(
self, info: dict[str, Any], context: Any = None, **kwargs: Any
) -> None:
"""Write a notification the the database"""
data = info.copy()
base_id = data.pop("base_id", None)
@@ -110,16 +118,19 @@ class SQLAlchemyDriver(base.Driver):
base_id,
)
def list_traces(self, fields=None):
def list_traces(
self, fields: set[str] | None = None
) -> list[dict[str, Any]]:
try:
from sqlalchemy.sql import select
from sqlalchemy.sql import select # type: ignore[import-not-found]
except ImportError:
raise exc.CommandError(
"To use this command, you should install 'SQLAlchemy'"
)
fields = set(fields or self.default_trace_fields)
stmt = select([self._data_table])
seen_ids = set()
result = []
seen_ids: set[str] = set()
result: list[dict[str, Any]] = []
traces = self._conn.execute(stmt).fetchall()
for trace in traces:
if trace["base_id"] not in seen_ids:
@@ -133,7 +144,7 @@ class SQLAlchemyDriver(base.Driver):
)
return result
def get_report(self, base_id):
def get_report(self, base_id: str) -> dict[str, Any]:
try:
from sqlalchemy.sql import select
except ImportError:
+3 -3
View File
@@ -17,11 +17,11 @@
class CommandError(Exception):
"""Invalid usage of CLI."""
def __init__(self, message=None):
def __init__(self, message: str | None = None) -> None:
self.message = message
def __str__(self):
return self.message or self.__class__.__doc__
def __str__(self) -> str:
return self.message or self.__class__.__doc__ or ""
class LogInsightAPIError(Exception):
+48 -28
View File
@@ -27,8 +27,10 @@ Guidelines for writing new hacking checks
import functools
import re
import tokenize
from collections.abc import Callable, Generator
from typing import Any
from hacking import core
from hacking import core # type: ignore[import-not-found]
re_no_construct_dict = re.compile(r"\sdict\(\)")
re_no_construct_list = re.compile(r"\slist\(\)")
@@ -47,11 +49,15 @@ re_str_format = re.compile(
re_raises = re.compile(r"\s:raise[^s] *.*$|\s:raises *:.*$|\s:raises *[^:]+$")
@core.flake8ext
def skip_ignored_lines(func):
@core.flake8ext # type: ignore[untyped-decorator]
def skip_ignored_lines(
func: Callable[..., Any],
) -> Callable[..., Any]:
@functools.wraps(func)
def wrapper(logical_line, filename):
def wrapper(
logical_line: str, filename: str
) -> Generator[tuple[int, str], None, None]:
line = logical_line.strip()
if not line or line.startswith("#") or line.endswith("# noqa"):
return
@@ -63,7 +69,9 @@ def skip_ignored_lines(func):
return wrapper
def _parse_assert_mock_str(line):
def _parse_assert_mock_str(
line: str,
) -> tuple[int, str, str] | tuple[None, None, None]:
point = line.find(".assert_")
if point != -1:
@@ -73,9 +81,11 @@ def _parse_assert_mock_str(line):
return None, None, None
@skip_ignored_lines
@core.flake8ext
def check_assert_methods_from_mock(logical_line, filename):
@skip_ignored_lines # type: ignore[untyped-decorator]
@core.flake8ext # type: ignore[untyped-decorator]
def check_assert_methods_from_mock(
logical_line: str, filename: str
) -> Generator[tuple[int, str], None, None]:
"""Ensure that ``assert_*`` methods from ``mock`` library is used correctly
N301 - base error number
@@ -135,9 +145,17 @@ def check_assert_methods_from_mock(logical_line, filename):
)
@skip_ignored_lines
@core.flake8ext
def check_quotes(logical_line, filename):
def _check_triple(line: str, i: int, char: str) -> bool:
return i + 2 < len(line) and (
char == line[i] == line[i + 1] == line[i + 2]
)
@skip_ignored_lines # type: ignore[untyped-decorator]
@core.flake8ext # type: ignore[untyped-decorator]
def check_quotes(
logical_line: str, filename: str
) -> Generator[tuple[int, str], None, None]:
"""Check that single quotation marks are not used
N350
@@ -147,11 +165,6 @@ def check_quotes(logical_line, filename):
in_multiline_string = False
single_quotas_are_used = False
def check_tripple(line, i, char):
return i + 2 < len(line) and (
char == line[i] == line[i + 1] == line[i + 2]
)
i = 0
while i < len(logical_line):
char = logical_line[i]
@@ -163,7 +176,7 @@ def check_quotes(logical_line, filename):
i += 1 # ignore next char
elif in_multiline_string:
if check_tripple(logical_line, i, "\""):
if _check_triple(logical_line, i, "\""):
i += 2 # skip next 2 chars
in_multiline_string = False
@@ -175,7 +188,7 @@ def check_quotes(logical_line, filename):
break
elif char == "\"":
if check_tripple(logical_line, i, "\""):
if _check_triple(logical_line, i, "\""):
in_multiline_string = True
i += 3
continue
@@ -187,9 +200,11 @@ def check_quotes(logical_line, filename):
yield (i, "N350 Remove Single quotes")
@skip_ignored_lines
@core.flake8ext
def check_no_constructor_data_struct(logical_line, filename):
@skip_ignored_lines # type: ignore[untyped-decorator]
@core.flake8ext # type: ignore[untyped-decorator]
def check_no_constructor_data_struct(
logical_line: str, filename: str
) -> Generator[tuple[int, str], None, None]:
"""Check that data structs (lists, dicts) are declared using literals
N351
@@ -203,8 +218,10 @@ def check_no_constructor_data_struct(logical_line, filename):
yield (0, "N351 Remove list() construct and use literal []")
@core.flake8ext
def check_dict_formatting_in_string(logical_line, tokens):
@core.flake8ext # type: ignore[untyped-decorator]
def check_dict_formatting_in_string(
logical_line: str, tokens: Any
) -> Generator[tuple[int, str], None, None]:
"""Check that strings do not use dict-formatting with a single replacement
N352
@@ -275,9 +292,11 @@ def check_dict_formatting_in_string(logical_line, tokens):
current_string = ""
@skip_ignored_lines
@core.flake8ext
def check_using_unicode(logical_line, filename):
@skip_ignored_lines # type: ignore[untyped-decorator]
@core.flake8ext # type: ignore[untyped-decorator]
def check_using_unicode(
logical_line: str, filename: str
) -> Generator[tuple[int, str], None, None]:
"""Check crosspython unicode usage
N353
@@ -291,8 +310,8 @@ def check_using_unicode(logical_line, filename):
)
@core.flake8ext
def check_raises(physical_line, filename):
@core.flake8ext # type: ignore[untyped-decorator]
def check_raises(physical_line: str, filename: str) -> tuple[int, str] | None:
"""Check raises usage
N354
@@ -309,3 +328,4 @@ def check_raises(physical_line, filename):
"N354 ':Please use ':raises Exception: conditions' "
"in docstrings.",
)
return None
+12 -1
View File
@@ -13,12 +13,23 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
from oslo_config import cfg
from osprofiler import notifier
from osprofiler import requests
from osprofiler import web
def init_from_conf(conf, context, project, service, host, **kwargs):
def init_from_conf(
conf: cfg.ConfigOpts,
context: Any,
project: str,
service: str,
host: str,
**kwargs: Any,
) -> None:
"""Initialize notifier from service configuration
:param conf: service configuration
+14 -8
View File
@@ -13,7 +13,9 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Callable
import logging
from typing import Any
from osprofiler.drivers import base
@@ -21,16 +23,18 @@ from osprofiler.drivers import base
LOG = logging.getLogger(__name__)
def _noop_notifier(info, context=None):
def _noop_notifier(info: dict[str, Any], context: Any = None) -> None:
"""Do nothing on notify()."""
# NOTE(boris-42): By default we are using noop notifier.
__notifier = _noop_notifier
__notifier_cache = {} # map: connection-string -> notifier
__notifier: Callable[..., None] = _noop_notifier
__notifier_cache: dict[
str, Callable[..., None]
] = {} # map: connection-string -> notifier
def notify(info):
def notify(info: dict[str, Any]) -> None:
"""Passes the profiling info to the notifier callable.
:param info: dictionary with profiling information
@@ -38,12 +42,12 @@ def notify(info):
__notifier(info)
def get():
def get() -> Callable[..., None]:
"""Returns notifier callable."""
return __notifier
def set(notifier):
def set(notifier: Callable[..., None]) -> None:
"""Service that are going to use profiler should set callable notifier.
Callable notifier is instance of callable object, that accept exactly
@@ -54,7 +58,9 @@ def set(notifier):
__notifier = notifier
def create(connection_string, *args, **kwargs):
def create(
connection_string: str, *args: Any, **kwargs: Any
) -> Callable[..., None]:
"""Create notifier based on specified plugin_name
:param connection_string: connection string which specifies the storage
@@ -83,5 +89,5 @@ def create(connection_string, *args, **kwargs):
return __notifier_cache[connection_string]
def clear_notifier_cache():
def clear_notifier_cache() -> None:
__notifier_cache.clear()
+20 -20
View File
@@ -247,17 +247,17 @@ cfg.CONF.register_opts(_OTLP_OPTS, group=_otlp_profiler_opt_group)
def set_defaults(
conf,
enabled=None,
trace_sqlalchemy=None,
hmac_keys=None,
connection_string=None,
es_doc_type=None,
es_scroll_time=None,
es_scroll_size=None,
socket_timeout=None,
sentinel_service_name=None,
):
conf: cfg.ConfigOpts,
enabled: bool | None = None,
trace_sqlalchemy: bool | None = None,
hmac_keys: str | None = None,
connection_string: str | None = None,
es_doc_type: str | None = None,
es_scroll_time: str | None = None,
es_scroll_size: int | None = None,
socket_timeout: float | None = None,
sentinel_service_name: str | None = None,
) -> None:
conf.register_opts(_PROFILER_OPTS, group=_profiler_opt_group)
if enabled is not None:
@@ -308,35 +308,35 @@ def set_defaults(
)
def is_trace_enabled(conf=None):
def is_trace_enabled(conf: cfg.ConfigOpts | None = None) -> bool:
if conf is None:
conf = cfg.CONF
return conf.profiler.enabled
return bool(conf.profiler.enabled)
def is_db_trace_enabled(conf=None):
def is_db_trace_enabled(conf: cfg.ConfigOpts | None = None) -> bool:
if conf is None:
conf = cfg.CONF
return conf.profiler.enabled and conf.profiler.trace_sqlalchemy
return bool(conf.profiler.enabled and conf.profiler.trace_sqlalchemy)
def enable_web_trace(conf=None):
def enable_web_trace(conf: cfg.ConfigOpts | None = None) -> None:
if conf is None:
conf = cfg.CONF
if conf.profiler.enabled:
web.enable(conf.profiler.hmac_keys)
def disable_web_trace(conf=None):
def disable_web_trace(conf: cfg.ConfigOpts | None = None) -> None:
if conf is None:
conf = cfg.CONF
if conf.profiler.enabled:
web.disable()
def list_opts():
def list_opts() -> list[tuple[str, list[cfg.Opt]]]:
return [
(_profiler_opt_group.name, _PROFILER_OPTS),
(_jaegerprofiler_opt_group, _JAEGER_OPTS),
(_otlp_profiler_opt_group, _OTLP_OPTS),
(_jaegerprofiler_opt_group.name, _JAEGER_OPTS),
(_otlp_profiler_opt_group.name, _OTLP_OPTS),
]
+81 -49
View File
@@ -14,10 +14,13 @@
# under the License.
import collections
from collections.abc import Callable
import functools
import inspect
import socket
import threading
import types
from typing import Any, ParamSpec, TypeVar, cast
from oslo_utils import reflection
from oslo_utils import timeutils
@@ -27,15 +30,21 @@ from osprofiler import _utils as utils
from osprofiler import notifier
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound=type)
# NOTE(boris-42): Thread safe storage for profiler instances.
__local_ctx = threading.local()
def clean():
def clean() -> None:
__local_ctx.profiler = None
def _ensure_no_multiple_traced(traceable_attrs):
def _ensure_no_multiple_traced(
traceable_attrs: list[tuple[str, Any]],
) -> None:
for attr_name, attr in traceable_attrs:
traced_times = getattr(attr, "__traced__", 0)
if traced_times:
@@ -46,7 +55,11 @@ def _ensure_no_multiple_traced(traceable_attrs):
)
def init(hmac_key, base_id=None, parent_id=None):
def init(
hmac_key: str,
base_id: str | None = None,
parent_id: str | None = None,
) -> "_Profiler":
"""Init profiler instance for current thread.
You should call profiler.init() before using osprofiler.
@@ -61,10 +74,10 @@ def init(hmac_key, base_id=None, parent_id=None):
__local_ctx.profiler = _Profiler(
hmac_key, base_id=base_id, parent_id=parent_id
)
return __local_ctx.profiler
return cast("_Profiler", __local_ctx.profiler)
def get():
def get() -> "_Profiler | None":
"""Get profiler instance.
:returns: Profiler instance or None if profiler wasn't inited.
@@ -72,7 +85,7 @@ def get():
return getattr(__local_ctx, "profiler", None)
def start(name, info=None):
def start(name: str, info: dict[str, Any] | None = None) -> None:
"""Send new start notification if profiler instance is presented.
:param name: The name of action. E.g. wsgi, rpc, db, etc..
@@ -84,7 +97,7 @@ def start(name, info=None):
profiler.start(name, info=info)
def stop(info=None):
def stop(info: dict[str, Any] | None = None) -> None:
"""Send new stop notification if profiler instance is presented."""
profiler = get()
if profiler:
@@ -92,12 +105,12 @@ def stop(info=None):
def trace(
name,
info=None,
hide_args=False,
hide_result=True,
allow_multiple_trace=True,
):
name: str,
info: dict[str, Any] | None = None,
hide_args: bool = False,
hide_result: bool = True,
allow_multiple_trace: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Trace decorator for functions.
Very useful if you would like to add trace point on existing function:
@@ -126,7 +139,7 @@ def trace(
info = info.copy()
info["function"] = {}
def decorator(f):
def decorator(f: Callable[P, R]) -> Callable[P, R]:
trace_times = getattr(f, "__traced__", 0)
if not allow_multiple_trace and trace_times:
raise ValueError(
@@ -134,19 +147,19 @@ def trace(
)
try:
f.__traced__ = trace_times + 1
setattr(f, "__traced__", trace_times + 1)
except AttributeError:
# Tries to work around the following:
#
# AttributeError: 'instancemethod' object has no
# attribute '__traced__'
try:
f.im_func.__traced__ = trace_times + 1
setattr(getattr(f, "im_func"), "__traced__", trace_times + 1)
except AttributeError: # nosec
pass
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# NOTE(tovin07): Workaround for this issue
# F823 local variable 'info'
# (defined in enclosing scope on line xxx)
@@ -161,7 +174,7 @@ def trace(
info_["function"]["args"] = str(args)
info_["function"]["kwargs"] = str(kwargs)
stop_info = None
stop_info: dict[str, Any] | None = None
try:
start(name, info=info_)
result = f(*args, **kwargs)
@@ -184,15 +197,15 @@ def trace(
def trace_cls(
name,
info=None,
hide_args=False,
hide_result=True,
trace_private=False,
allow_multiple_trace=True,
trace_class_methods=False,
trace_static_methods=False,
):
name: str,
info: dict[str, Any] | None = None,
hide_args: bool = False,
hide_result: bool = True,
trace_private: bool = False,
allow_multiple_trace: bool = True,
trace_class_methods: bool = False,
trace_static_methods: bool = False,
) -> Callable[[T], T]:
"""Trace decorator for instances of class .
Very useful if you would like to add trace point on existing method:
@@ -230,7 +243,9 @@ def trace_cls(
tracing is not allowed (by default allow).
"""
def trace_checker(attr_name, to_be_wrapped):
def trace_checker(
attr_name: str, to_be_wrapped: Any
) -> tuple[bool, type | None]:
if attr_name.startswith("__"):
# Never trace really private methods.
return (False, None)
@@ -246,11 +261,11 @@ def trace_cls(
return (True, classmethod)
return (True, None)
def decorator(cls):
def decorator(cls: T) -> T:
clss = cls if inspect.isclass(cls) else cls.__class__
mro_dicts = [c.__dict__ for c in inspect.getmro(clss)]
traceable_attrs = []
traceable_wrappers = []
traceable_attrs: list[tuple[str, Any]] = []
traceable_wrappers: list[type | None] = []
for attr_name, attr in inspect.getmembers(cls):
if not (inspect.ismethod(attr) or inspect.isfunction(attr)):
continue
@@ -305,7 +320,12 @@ class TracedMeta(type):
traced - E.g. wsgi, rpc, db, etc...
"""
def __init__(cls, cls_name, bases, attrs):
def __init__(
cls,
cls_name: str,
bases: tuple[type, ...],
attrs: dict[str, Any],
) -> None:
super().__init__(cls_name, bases, attrs)
trace_args = dict(getattr(cls, "__trace_args__", {}))
@@ -318,7 +338,7 @@ class TracedMeta(type):
"e.g. __trace_args__ = {'name': 'rpc'}"
)
traceable_attrs = []
traceable_attrs: list[tuple[str, Any]] = []
for attr_name, attr_value in attrs.items():
if not (
inspect.ismethod(attr_value) or inspect.isfunction(attr_value)
@@ -340,7 +360,7 @@ class TracedMeta(type):
class Trace:
def __init__(self, name, info=None):
def __init__(self, name: str, info: dict[str, Any] | None = None) -> None:
"""With statement way to use profiler start()/stop().
@@ -358,12 +378,17 @@ class Trace:
self._name = name
self._info = info
def __enter__(self):
def __enter__(self) -> None:
start(self._name, info=self._info)
def __exit__(self, etype, value, traceback):
def __exit__(
self,
etype: type[BaseException] | None,
value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
info = None
if etype:
if etype and value is not None:
info = {
"etype": reflection.get_class_name(etype),
"message": value.args[0] if value.args else None,
@@ -372,15 +397,22 @@ class Trace:
class _Profiler:
def __init__(self, hmac_key, base_id=None, parent_id=None):
def __init__(
self,
hmac_key: str,
base_id: str | None = None,
parent_id: str | None = None,
) -> None:
self.hmac_key = hmac_key
if not base_id:
base_id = str(uuidutils.generate_uuid())
self._trace_stack = collections.deque([base_id, parent_id or base_id])
self._name = collections.deque()
self._host = socket.gethostname()
self._trace_stack: collections.deque[str] = collections.deque(
[base_id, parent_id or base_id]
)
self._name: collections.deque[str] = collections.deque()
self._host: str = socket.gethostname()
def get_shorten_id(self, uuid_id):
def get_shorten_id(self, uuid_id: str | int) -> str:
"""Return shorten id of a uuid that will be used in OpenTracing drivers
:param uuid_id: A string of uuid that was generated by uuidutils
@@ -388,7 +420,7 @@ class _Profiler:
"""
return format(utils.shorten_id(uuid_id), "x")
def get_base_id(self):
def get_base_id(self) -> str:
"""Return base id of a trace.
Base id is the same for all elements in one trace. It's main goal is
@@ -396,15 +428,15 @@ class _Profiler:
"""
return self._trace_stack[0]
def get_parent_id(self):
def get_parent_id(self) -> str:
"""Returns parent trace element id."""
return self._trace_stack[-2]
def get_id(self):
def get_id(self) -> str:
"""Returns current trace element id."""
return self._trace_stack[-1]
def start(self, name, info=None):
def start(self, name: str, info: dict[str, Any] | None = None) -> None:
"""Start new event.
Adds new trace_id to trace stack and sends notification
@@ -424,7 +456,7 @@ class _Profiler:
self._trace_stack.append(str(uuidutils.generate_uuid()))
self._notify(f"{name}-start", info)
def stop(self, info=None):
def stop(self, info: dict[str, Any] | None = None) -> None:
"""Finish latest event.
Same as a start, but instead of pushing trace_id to stack it pops it.
@@ -436,8 +468,8 @@ class _Profiler:
self._notify(f"{self._name.pop()}-stop", info)
self._trace_stack.pop()
def _notify(self, name, info):
payload = {
def _notify(self, name: str, info: dict[str, Any]) -> None:
payload: dict[str, Any] = {
"name": name,
"base_id": self.get_base_id(),
"trace_id": self.get_id(),
View File
+9 -5
View File
@@ -12,7 +12,9 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Callable
import logging as log
from typing import Any
from urllib import parse as parser
from osprofiler import profiler
@@ -24,7 +26,7 @@ from osprofiler import web
LOG = log.getLogger(__name__)
_FUNC = None
_FUNC: Callable[..., Any] | None = None
try:
from requests.adapters import HTTPAdapter
@@ -32,11 +34,11 @@ except ImportError:
pass
else:
def send(self, request, *args, **kwargs):
def send(self: Any, request: Any, *args: Any, **kwargs: Any) -> Any:
parsed_url = parser.urlparse(request.url)
# Best effort guessing port if needed
port = parsed_url.port or ""
port: int | str = parsed_url.port or ""
if not port and parsed_url.scheme == "http":
port = 80
elif not port and parsed_url.scheme == "https":
@@ -60,6 +62,8 @@ else:
# context/span.
request.headers.update(web.get_trace_id_headers())
if _FUNC is None:
raise RuntimeError("osprofiler requests adapter not initialized")
response = _FUNC(self, request, *args, **kwargs)
profiler.stop(info={"requests": {"status_code": response.status_code}})
@@ -69,9 +73,9 @@ else:
_FUNC = HTTPAdapter.send
def enable():
def enable() -> None:
if _FUNC:
HTTPAdapter.send = send
HTTPAdapter.send = send # type: ignore[method-assign]
LOG.debug("profiling requests enabled")
else:
LOG.warning(
+27 -9
View File
@@ -13,8 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Generator
import contextlib
import logging as log
from typing import Any
from oslo_utils import reflection
@@ -25,20 +27,22 @@ LOG = log.getLogger(__name__)
_DISABLED = False
def disable():
def disable() -> None:
"""Disable tracing of all DB queries. Reduce a lot size of profiles."""
global _DISABLED
_DISABLED = True
def enable():
def enable() -> None:
"""add_tracing adds event listeners for sqlalchemy."""
global _DISABLED
_DISABLED = False
def add_tracing(sqlalchemy, engine, name, hide_result=True):
def add_tracing(
sqlalchemy: Any, engine: Any, name: str, hide_result: bool = True
) -> None:
"""Add tracing to all sqlalchemy calls."""
if not _DISABLED:
@@ -54,7 +58,7 @@ def add_tracing(sqlalchemy, engine, name, hide_result=True):
@contextlib.contextmanager
def wrap_session(sqlalchemy, sess):
def wrap_session(sqlalchemy: Any, sess: Any) -> Generator[Any, None, None]:
with sess as s:
if not getattr(s.bind, "traced", False):
add_tracing(sqlalchemy, s.bind, "db")
@@ -62,17 +66,24 @@ def wrap_session(sqlalchemy, sess):
yield s
def _before_cursor_execute(name):
def _before_cursor_execute(name: str) -> Any:
"""Add listener that will send trace info before query is executed."""
def handler(conn, cursor, statement, params, context, executemany):
def handler(
conn: Any,
cursor: Any,
statement: Any,
params: Any,
context: Any,
executemany: Any,
) -> None:
info = {"db": {"statement": statement, "params": params}}
profiler.start(name, info=info)
return handler
def _after_cursor_execute(hide_result=True):
def _after_cursor_execute(hide_result: bool = True) -> Any:
"""Add listener that will send trace info after query is executed.
:param hide_result: Boolean value to hide or show SQL result in trace.
@@ -80,7 +91,14 @@ def _after_cursor_execute(hide_result=True):
False - show SQL result in trace.
"""
def handler(conn, cursor, statement, params, context, executemany):
def handler(
conn: Any,
cursor: Any,
statement: Any,
params: Any,
context: Any,
executemany: Any,
) -> None:
if not hide_result:
# Add SQL result to trace info in *-stop phase
info = {"db": {"result": str(cursor._rows)}}
@@ -91,7 +109,7 @@ def _after_cursor_execute(hide_result=True):
return handler
def handle_error(exception_context):
def handle_error(exception_context: Any) -> None:
"""Handle SQLAlchemy errors"""
exception_class_name = reflection.get_class_name(
exception_context.original_exception
+6 -2
View File
@@ -86,7 +86,9 @@ class DriverTestCase(test.FunctionalTestCase):
profiler.init("SECRET_KEY")
# grab base_id
base_id = profiler.get().get_base_id()
p = profiler.get()
assert p is not None # noqa: S101
base_id = p.get_base_id()
# execute profiled code
foo = Foo()
@@ -150,7 +152,9 @@ class RedisDriverTestCase(DriverTestCase):
profiler.init("SECRET_KEY")
# grab base_id
base_id = profiler.get().get_base_id()
p = profiler.get()
assert p is not None # noqa: S101
base_id = p.get_base_id()
# execute profiled code
foo = Foo()
+8 -4
View File
@@ -17,6 +17,7 @@ import io
import json
import os
import sys
from typing import cast
from unittest import mock
import ddt
@@ -36,7 +37,8 @@ class ShellTestCase(test.TestCase):
def tearDown(self):
super().tearDown()
os.environ = self.old_environment
os.environ.clear()
os.environ.update(self.old_environment)
def _trace_show_cmd(self, format_=None):
cmd = f"trace show --connection-string redis:// {self.TRACE_ID}"
@@ -47,7 +49,9 @@ class ShellTestCase(test.TestCase):
def test_shell_main(self, mock_shell):
mock_shell.side_effect = exc.CommandError("some_message")
shell.main()
self.assertEqual("some_message\n", sys.stdout.getvalue())
self.assertEqual(
"some_message\n", cast(io.StringIO, sys.stdout).getvalue()
)
def run_command(self, cmd):
shell.OSProfilerShell(cmd.split())
@@ -117,7 +121,7 @@ class ShellTestCase(test.TestCase):
separators=(",", ": "),
)
),
sys.stdout.getvalue(),
cast(io.StringIO, sys.stdout).getvalue(),
)
@mock.patch("sys.stdout", io.StringIO())
@@ -153,7 +157,7 @@ class ShellTestCase(test.TestCase):
"\n".format(
json.dumps(notifications, indent=4, separators=(",", ": "))
),
sys.stdout.getvalue(),
cast(io.StringIO, sys.stdout).getvalue(),
)
@mock.patch("sys.stdout", io.StringIO())
+4 -4
View File
@@ -25,10 +25,10 @@ class NotifierBaseTestCase(test.TestCase):
def get_name(cls):
return "a"
def notify(self, a):
def notify(self, a): # type: ignore[override]
return a
self.assertEqual(10, base.get_driver("a://").notify(10))
self.assertEqual(10, base.get_driver("a://").notify(10)) # type: ignore[arg-type, func-returns-value]
def test_factory_with_args(self):
@@ -41,10 +41,10 @@ class NotifierBaseTestCase(test.TestCase):
def get_name(cls):
return "b"
def notify(self, c):
def notify(self, c): # type: ignore[override]
return self.a + self.b + c
self.assertEqual(22, base.get_driver("b://", 5, b=7).notify(10))
self.assertEqual(22, base.get_driver("b://", 5, b=7).notify(10)) # type: ignore[arg-type, func-returns-value]
def test_driver_not_found(self):
self.assertRaises(
@@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
from unittest import mock
from oslo_serialization import jsonutils
@@ -131,7 +132,7 @@ class RedisParserTestCase(test.TestCase):
def test_get_report(self):
self.redisdb.db = mock.MagicMock()
result_elements = [
result_elements: list[dict[str, Any]] = [
{
"info": {
"project": None,
+1 -1
View File
@@ -27,7 +27,7 @@ class InitializerTestCase(testtools.TestCase):
conf = mock.Mock()
conf.profiler.connection_string = "driver://"
conf.profiler.hmac_keys = "hmac_keys"
context = {}
context: dict[object, object] = {}
project = "my-project"
service = "my-service"
host = "my-host"
+1 -1
View File
@@ -39,7 +39,7 @@ class NotifierTestCase(test.TestCase):
def test_notify(self):
m = mock.MagicMock()
notifier.set(m)
notifier.notify(10)
notifier.notify(10) # type: ignore[arg-type]
m.assert_called_once_with(10)
+22 -10
View File
@@ -17,6 +17,8 @@ import collections
import copy
import datetime
import re
import unittest
from typing import Any, ClassVar
from unittest import mock
@@ -43,8 +45,8 @@ class ProfilerGlobMethodsTestCase(test.TestCase):
def test_start(self):
p = profiler.init("secret", base_id="1", parent_id="2")
p.start = mock.MagicMock()
profiler.start("name", info="info")
p.start = mock.MagicMock() # type: ignore[method-assign]
profiler.start("name", info="info") # type: ignore[arg-type]
p.start.assert_called_once_with("name", info="info")
def test_stop_not_inited(self):
@@ -53,8 +55,8 @@ class ProfilerGlobMethodsTestCase(test.TestCase):
def test_stop(self):
p = profiler.init("secret", base_id="1", parent_id="2")
p.stop = mock.MagicMock()
profiler.stop(info="info")
p.stop = mock.MagicMock() # type: ignore[method-assign]
profiler.stop(info="info") # type: ignore[arg-type]
p.stop.assert_called_once_with(info="info")
@@ -159,10 +161,10 @@ class WithTraceTestCase(test.TestCase):
@mock.patch("osprofiler.profiler.start")
def test_with_trace(self, mock_start, mock_stop):
with profiler.Trace("a", info="a1"):
with profiler.Trace("a", info="a1"): # type: ignore[arg-type]
mock_start.assert_called_once_with("a", info="a1")
mock_start.reset_mock()
with profiler.Trace("b", info="b1"):
with profiler.Trace("b", info="b1"): # type: ignore[arg-type]
mock_start.assert_called_once_with("b", info="b1")
mock_stop.assert_called_once_with(info=None)
mock_stop.reset_mock()
@@ -452,7 +454,7 @@ class TraceClsDecoratorTestCase(test.TestCase):
@mock.patch("osprofiler.profiler.stop")
@mock.patch("osprofiler.profiler.start")
@test.testcase.skip(
@unittest.skip(
"Static method tracing was disabled due the bug. This test should be "
"skipped until we find the way to address it."
)
@@ -499,7 +501,10 @@ class TraceClsDecoratorTestCase(test.TestCase):
class FakeTraceWithMetaclassBase(metaclass=profiler.TracedMeta):
__trace_args__ = {"name": "rpc", "info": {"a": 10}}
__trace_args__: ClassVar[dict[str, Any]] = {
"name": "rpc",
"info": {"a": 10},
}
def method1(self, a, b, c=10):
return a + b + c
@@ -520,14 +525,21 @@ class FakeTraceDummy(FakeTraceWithMetaclassBase):
class FakeTraceWithMetaclassHideArgs(FakeTraceWithMetaclassBase):
__trace_args__ = {"name": "a", "info": {"b": 20}, "hide_args": True}
__trace_args__: ClassVar[dict[str, Any]] = {
"name": "a",
"info": {"b": 20},
"hide_args": True,
}
def method5(self, k, l):
return k + l
class FakeTraceWithMetaclassPrivate(FakeTraceWithMetaclassBase):
__trace_args__ = {"name": "rpc", "trace_private": True}
__trace_args__: ClassVar[dict[str, Any]] = {
"name": "rpc",
"trace_private": True,
}
def _new_private_method(self, m):
return 2 * m
+2
View File
@@ -68,6 +68,7 @@ class UtilsTestCase(test.TestCase):
process_data = utils.signed_unpack(packed_data, hmac_data, [hmac])
self.assertIn("hmac_key", process_data)
assert process_data is not None # noqa: S101
process_data.pop("hmac_key")
self.assertEqual(data, process_data)
@@ -77,6 +78,7 @@ class UtilsTestCase(test.TestCase):
packed_data, hmac_data = utils.signed_pack(data, keys[-1])
process_data = utils.signed_unpack(packed_data, hmac_data, keys)
assert process_data is not None # noqa: S101
self.assertEqual(keys[-1], process_data["hmac_key"])
def test_signed_pack_unpack_many_wrong_keys(self):
+11 -9
View File
@@ -35,7 +35,7 @@ class WebTestCase(test.TestCase):
self.addCleanup(profiler.clean)
def test_get_trace_id_headers_no_hmac(self):
profiler.init(None, base_id="y", parent_id="z")
profiler.init(None, base_id="y", parent_id="z") # type: ignore[arg-type]
headers = web.get_trace_id_headers()
self.assertEqual(headers, {})
@@ -50,6 +50,7 @@ class WebTestCase(test.TestCase):
headers["X-Trace-Info"], headers["X-Trace-HMAC"], ["key"]
)
self.assertIn("hmac_key", trace_info)
assert trace_info is not None # noqa: S101
self.assertEqual("key", trace_info.pop("hmac_key"))
self.assertEqual({"parent_id": "z", "base_id": "y"}, trace_info)
@@ -91,9 +92,9 @@ class WebMiddlewareTestCase(test.TestCase):
request.get_response.return_value = "yeah!"
request.headers = headers
middleware = web.WsgiMiddleware("app", hmac_key, enabled=enabled)
middleware = web.WsgiMiddleware(mock.ANY, hmac_key, enabled=enabled)
self.assertEqual("yeah!", middleware(request))
request.get_response.assert_called_once_with("app")
request.get_response.assert_called_once_with(mock.ANY)
self.assertEqual(0, mock_profiler_init.call_count)
@mock.patch("osprofiler.web.profiler.init")
@@ -157,7 +158,8 @@ class WebMiddlewareTestCase(test.TestCase):
def test_wsgi_middleware_invalid_trace_info(self, mock_profiler_init):
hmac_key = "secret"
pack = utils.signed_pack(
[{"base_id": "1"}, {"parent_id": "2"}], hmac_key
[{"base_id": "1"}, {"parent_id": "2"}], # type: ignore[arg-type]
hmac_key,
)
headers = {
"a": "1",
@@ -191,7 +193,7 @@ class WebMiddlewareTestCase(test.TestCase):
}
middleware = web.WsgiMiddleware(
"app", f"secret1,{hmac_key}", enabled=True
mock.ANY, f"secret1,{hmac_key}", enabled=True
)
self.assertEqual("yeah!", middleware(request))
mock_profiler_init.assert_called_once_with(
@@ -220,7 +222,7 @@ class WebMiddlewareTestCase(test.TestCase):
}
middleware = web.WsgiMiddleware(
"app", f"{hmac_key},secret2", enabled=True
mock.ANY, f"{hmac_key},secret2", enabled=True
)
self.assertEqual("yeah!", middleware(request))
mock_profiler_init.assert_called_once_with(
@@ -249,7 +251,7 @@ class WebMiddlewareTestCase(test.TestCase):
"X-Trace-HMAC": pack[1],
}
middleware = web.WsgiMiddleware("app", hmac_key, enabled=True)
middleware = web.WsgiMiddleware(mock.ANY, hmac_key, enabled=True)
self.assertEqual("yeah!", middleware(request))
mock_profiler_init.assert_called_once_with(
hmac_key=hmac_key, base_id="1", parent_id="2"
@@ -269,7 +271,7 @@ class WebMiddlewareTestCase(test.TestCase):
request = mock.MagicMock()
request.get_response.return_value = "yeah!"
web.disable()
middleware = web.WsgiMiddleware("app", "hmac_key", enabled=True)
middleware = web.WsgiMiddleware(mock.ANY, "hmac_key", enabled=True)
self.assertEqual("yeah!", middleware(request))
self.assertEqual(mock_profiler_init.call_count, 0)
@@ -294,7 +296,7 @@ class WebMiddlewareTestCase(test.TestCase):
}
web.enable("super_secret_key1,super_secret_key2")
middleware = web.WsgiMiddleware("app", enabled=True)
middleware = web.WsgiMiddleware(mock.ANY, enabled=True)
self.assertEqual("yeah!", middleware(request))
mock_profiler_init.assert_called_once_with(
hmac_key=hmac_key, base_id="1", parent_id="2"
+35 -12
View File
@@ -13,11 +13,18 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any, TypeGuard, TYPE_CHECKING
import webob.dec
from osprofiler import _utils as utils
from osprofiler import profiler
if TYPE_CHECKING:
from _typeshed.wsgi import WSGIApplication
# Trace keys that are required or optional, any other
# keys that are present will cause the trace to be rejected...
@@ -31,21 +38,21 @@ X_TRACE_INFO = "X-Trace-Info"
X_TRACE_HMAC = "X-Trace-HMAC"
def get_trace_id_headers():
def get_trace_id_headers() -> dict[str, str]:
"""Adds the trace id headers (and any hmac) into provided dictionary."""
p = profiler.get()
if p and p.hmac_key:
data = {"base_id": p.get_base_id(), "parent_id": p.get_id()}
pack = utils.signed_pack(data, p.hmac_key)
return {X_TRACE_INFO: pack[0], X_TRACE_HMAC: pack[1]}
return {X_TRACE_INFO: pack[0].decode(), X_TRACE_HMAC: pack[1] or ""}
return {}
_ENABLED = None
_HMAC_KEYS = None
_ENABLED: bool | None = None
_HMAC_KEYS: list[str] | tuple[str, ...] | None = None
def disable():
def disable() -> None:
"""Disable middleware.
This is the alternative way to disable middleware. It will be used to be
@@ -55,7 +62,7 @@ def disable():
_ENABLED = False
def enable(hmac_keys=None):
def enable(hmac_keys: str | None = None) -> None:
"""Enable middleware."""
global _ENABLED, _HMAC_KEYS
_ENABLED = True
@@ -65,7 +72,13 @@ def enable(hmac_keys=None):
class WsgiMiddleware:
"""WSGI Middleware that enables tracing for an application."""
def __init__(self, application, hmac_keys=None, enabled=False, **kwargs):
def __init__(
self,
application: WSGIApplication,
hmac_keys: str | None = None,
enabled: bool = False,
**kwargs: Any,
) -> None:
"""Initialize middleware with api-paste.ini arguments.
:application: wsgi app
@@ -86,13 +99,17 @@ class WsgiMiddleware:
self.hmac_keys = utils.split(hmac_keys or "")
@classmethod
def factory(cls, global_conf, **local_conf):
def filter_(app):
def factory(
cls, global_conf: dict[str, Any] | None, **local_conf: Any
) -> Any:
def filter_(app: Any) -> WsgiMiddleware:
return cls(app, **local_conf)
return filter_
def _trace_is_valid(self, trace_info):
def _trace_is_valid(
self, trace_info: dict[str, Any] | None
) -> TypeGuard[dict[str, Any]]:
if not isinstance(trace_info, dict):
return False
trace_keys = set(trace_info.keys())
@@ -103,7 +120,9 @@ class WsgiMiddleware:
return True
@webob.dec.wsgify
def __call__(self, request):
def __call__(
self, request: webob.request.Request
) -> webob.response.Response:
if (
_ENABLED is not None
and not _ENABLED
@@ -121,7 +140,11 @@ class WsgiMiddleware:
if not self._trace_is_valid(trace_info):
return request.get_response(self.application)
profiler.init(**trace_info)
profiler.init(
hmac_key=trace_info["hmac_key"],
base_id=trace_info.get("base_id"),
parent_id=trace_info.get("parent_id"),
)
info = {
"request": {
"path": request.path,
+15
View File
@@ -65,6 +65,21 @@ packages = [
"osprofiler"
]
[tool.mypy]
python_version = "3.10"
show_column_numbers = true
show_error_context = true
strict = true
disable_error_code = ["import-untyped"]
exclude = "(?x)(doc | releasenotes)"
[[tool.mypy.overrides]]
module = ["osprofiler.tests.*"]
disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_subclassing_any = false
disable_error_code = ["import-untyped"]
[tool.ruff]
line-length = 79
+21
View File
@@ -24,10 +24,31 @@ deps =
oslo.messaging
[testenv:pep8]
description =
Run style checks.
deps =
pre-commit
{[testenv:mypy]deps}
commands =
pre-commit run -a
{[testenv:mypy]commands}
[testenv:mypy]
description =
Run type checks.
deps =
{[testenv]deps}
mypy
types-PySocks
types-WebOb
types-netaddr
types-protobuf
types-psutil
types-python-dateutil
types-requests
types-setuptools
commands =
mypy --cache-dir="{envdir}/mypy_cache" {posargs:osprofiler}
[testenv:venv]
commands = {posargs}