Add typing
Change-Id: Ib00059b676a8fce50a8390b13b52ff5aa5805739 Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
+40
-15
@@ -19,11 +19,28 @@ import hmac
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
_C = TypeVar("_C")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def split(text, strip=True):
|
||||
_AlreadySplit = TypeVar('_AlreadySplit', bound=list[Any] | tuple[Any, ...])
|
||||
|
||||
|
||||
@overload
|
||||
def split(text: str, strip: bool = True) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def split(text: _AlreadySplit, strip: bool = True) -> _AlreadySplit: ...
|
||||
|
||||
|
||||
def split(
|
||||
text: str | _AlreadySplit, strip: bool = True
|
||||
) -> list[str] | _AlreadySplit:
|
||||
"""Splits a comma separated text blob into its components.
|
||||
|
||||
Does nothing if already a list or tuple.
|
||||
@@ -38,7 +55,7 @@ def split(text, strip=True):
|
||||
return text.split(",")
|
||||
|
||||
|
||||
def binary_encode(text, encoding="utf-8"):
|
||||
def binary_encode(text: bytes | str, encoding: str = "utf-8") -> bytes:
|
||||
"""Converts a string of into a binary type using given encoding.
|
||||
|
||||
Does nothing if text not unicode string.
|
||||
@@ -51,7 +68,7 @@ def binary_encode(text, encoding="utf-8"):
|
||||
raise TypeError("Expected binary or string type")
|
||||
|
||||
|
||||
def binary_decode(data, encoding="utf-8"):
|
||||
def binary_decode(data: bytes | str, encoding: str = "utf-8") -> str:
|
||||
"""Converts a binary type into a text type using given encoding.
|
||||
|
||||
Does nothing if data is already unicode string.
|
||||
@@ -64,14 +81,16 @@ def binary_decode(data, encoding="utf-8"):
|
||||
raise TypeError("Expected binary or string type")
|
||||
|
||||
|
||||
def generate_hmac(data, hmac_key):
|
||||
def generate_hmac(data: bytes | str, hmac_key: bytes | str) -> str:
|
||||
"""Generate a hmac using a known key given the provided content."""
|
||||
h = hmac.new(binary_encode(hmac_key), digestmod=hashlib.sha1)
|
||||
h.update(binary_encode(data))
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def signed_pack(data, hmac_key):
|
||||
def signed_pack(
|
||||
data: dict[str, str], hmac_key: str | None
|
||||
) -> tuple[bytes, str | None]:
|
||||
"""Pack and sign data with hmac_key."""
|
||||
raw_data = base64.urlsafe_b64encode(binary_encode(json.dumps(data)))
|
||||
|
||||
@@ -82,7 +101,11 @@ def signed_pack(data, hmac_key):
|
||||
return raw_data, generate_hmac(raw_data, hmac_key) if hmac_key else None
|
||||
|
||||
|
||||
def signed_unpack(data, hmac_data, hmac_keys):
|
||||
def signed_unpack(
|
||||
data: str | bytes | None,
|
||||
hmac_data: str | None,
|
||||
hmac_keys: Sequence[str] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Unpack data and check that it was signed with hmac_key.
|
||||
|
||||
:param data: json string that was singed_packed.
|
||||
@@ -95,7 +118,7 @@ def signed_unpack(data, hmac_data, hmac_keys):
|
||||
"""
|
||||
# NOTE(boris-42): For security reason, if there is no hmac_data or
|
||||
# hmac_keys we don't trust data => return None.
|
||||
if not (hmac_keys and hmac_data):
|
||||
if not hmac_keys or not hmac_data or not data:
|
||||
return None
|
||||
hmac_data = hmac_data.strip()
|
||||
if not hmac_data:
|
||||
@@ -108,7 +131,7 @@ def signed_unpack(data, hmac_data, hmac_keys):
|
||||
else:
|
||||
if hmac.compare_digest(hmac_data, user_hmac_data):
|
||||
try:
|
||||
contents = json.loads(
|
||||
contents: dict[str, Any] = json.loads(
|
||||
binary_decode(base64.urlsafe_b64decode(data))
|
||||
)
|
||||
contents["hmac_key"] = hmac_key
|
||||
@@ -118,14 +141,16 @@ def signed_unpack(data, hmac_data, hmac_keys):
|
||||
return None
|
||||
|
||||
|
||||
def itersubclasses(cls, _seen=None):
|
||||
def itersubclasses(
|
||||
cls: type[_C], _seen: set[type] | None = None
|
||||
) -> Generator[type[_C], None, None]:
|
||||
"""Generator over all subclasses of a given class in depth first order."""
|
||||
|
||||
_seen = _seen or set()
|
||||
try:
|
||||
subs = cls.__subclasses__()
|
||||
except TypeError: # fails only when cls is type
|
||||
subs = cls.__subclasses__(cls)
|
||||
subs = cls.__subclasses__(cls) # type: ignore[call-arg]
|
||||
for sub in subs:
|
||||
if sub not in _seen:
|
||||
_seen.add(sub)
|
||||
@@ -134,13 +159,13 @@ def itersubclasses(cls, _seen=None):
|
||||
yield sub
|
||||
|
||||
|
||||
def import_modules_from_package(package):
|
||||
def import_modules_from_package(package: str) -> None:
|
||||
"""Import modules from package and append into sys.modules
|
||||
|
||||
:param: package - Full package name. For example: rally.deploy.engines
|
||||
"""
|
||||
path = [os.path.dirname(__file__), ".."] + package.split(".")
|
||||
path = os.path.join(*path)
|
||||
path_parts = [os.path.dirname(__file__), ".."] + package.split(".")
|
||||
path = os.path.join(*path_parts)
|
||||
for root, dirs, files in os.walk(path):
|
||||
for filename in files:
|
||||
if filename.startswith("__") or not filename.endswith(".py"):
|
||||
@@ -150,7 +175,7 @@ def import_modules_from_package(package):
|
||||
__import__(module_name)
|
||||
|
||||
|
||||
def shorten_id(span_id):
|
||||
def shorten_id(span_id: str | int) -> int:
|
||||
"""Convert from uuid4 to 64 bit id for OpenTracing"""
|
||||
int64_max = (1 << 64) - 1
|
||||
if isinstance(span_id, int):
|
||||
@@ -163,7 +188,7 @@ def shorten_id(span_id):
|
||||
return short_id
|
||||
|
||||
|
||||
def uuid_to_int128(span_uuid):
|
||||
def uuid_to_int128(span_uuid: str | int) -> int:
|
||||
"""Convert from uuid4 to 128 bit id for OpenTracing"""
|
||||
if isinstance(span_uuid, int):
|
||||
return span_uuid
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
def env(*args, **kwargs):
|
||||
def env(*args: str, **kwargs: str) -> str:
|
||||
"""Returns the first environment variable set.
|
||||
|
||||
If all are empty, defaults to '' or keyword arg `default`.
|
||||
@@ -28,7 +30,9 @@ def env(*args, **kwargs):
|
||||
return kwargs.get("default", "")
|
||||
|
||||
|
||||
def arg(*args, **kwargs):
|
||||
def arg(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Decorator for CLI args.
|
||||
|
||||
Example:
|
||||
@@ -38,22 +42,22 @@ def arg(*args, **kwargs):
|
||||
... pass
|
||||
"""
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
add_arg(func, *args, **kwargs)
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
def add_arg(func, *args, **kwargs):
|
||||
def add_arg(func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
|
||||
"""Bind CLI arguments to a shell.py `do_foo` function."""
|
||||
|
||||
if not hasattr(func, "arguments"):
|
||||
func.arguments = []
|
||||
setattr(func, "arguments", [])
|
||||
|
||||
# NOTE(sirp): avoid dups that can occur when the module is shared across
|
||||
# tests.
|
||||
if (args, kwargs) not in func.arguments:
|
||||
if (args, kwargs) not in getattr(func, "arguments"):
|
||||
# Because of the semantics of decorator composition if we just append
|
||||
# to the options list positional options will appear to be backwards.
|
||||
func.arguments.insert(0, (args, kwargs))
|
||||
getattr(func, "arguments").insert(0, (args, kwargs))
|
||||
|
||||
+13
-11
@@ -13,8 +13,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from oslo_utils import encodeutils
|
||||
from oslo_utils import uuidutils
|
||||
@@ -26,7 +28,7 @@ from osprofiler import exc
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
group_name = None
|
||||
group_name: str | None = None
|
||||
|
||||
|
||||
class TraceCommands(BaseCommand):
|
||||
@@ -84,7 +86,7 @@ class TraceCommands(BaseCommand):
|
||||
help="filename for rendering the dot graph in pdf format",
|
||||
)
|
||||
@cliutils.arg("--out", dest="file_name", help="save output in file")
|
||||
def show(self, args):
|
||||
def show(self, args: argparse.Namespace) -> None:
|
||||
"""Display trace results in HTML, JSON or DOT format."""
|
||||
|
||||
if not args.conn_str:
|
||||
@@ -102,7 +104,7 @@ class TraceCommands(BaseCommand):
|
||||
try:
|
||||
engine = base.get_driver(args.conn_str, **args.__dict__)
|
||||
except Exception as e:
|
||||
raise exc.CommandError(e.message)
|
||||
raise exc.CommandError(str(e))
|
||||
|
||||
trace = engine.get_report(args.trace)
|
||||
|
||||
@@ -115,7 +117,7 @@ class TraceCommands(BaseCommand):
|
||||
|
||||
# Since datetime.datetime is not JSON serializable by default,
|
||||
# this method will handle that.
|
||||
def datetime_json_serialize(obj):
|
||||
def datetime_json_serialize(obj: Any) -> Any:
|
||||
if hasattr(obj, "isoformat"):
|
||||
return obj.isoformat()
|
||||
else:
|
||||
@@ -162,9 +164,9 @@ class TraceCommands(BaseCommand):
|
||||
else:
|
||||
print(output)
|
||||
|
||||
def _create_dot_graph(self, trace):
|
||||
def _create_dot_graph(self, trace: dict[str, Any]) -> Any:
|
||||
try:
|
||||
import graphviz
|
||||
import graphviz # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
raise exc.CommandError(
|
||||
"graphviz library is required to use this option."
|
||||
@@ -173,7 +175,7 @@ class TraceCommands(BaseCommand):
|
||||
dot = graphviz.Digraph(format="pdf")
|
||||
next_id = [0]
|
||||
|
||||
def _create_node(info):
|
||||
def _create_node(info: dict[str, Any]) -> str:
|
||||
time_taken = info["finished"] - info["started"]
|
||||
service = info["service"] + ":" if "service" in info else ""
|
||||
name = info["name"]
|
||||
@@ -194,7 +196,7 @@ class TraceCommands(BaseCommand):
|
||||
dot.node(node_id, label)
|
||||
return node_id
|
||||
|
||||
def _create_sub_graph(root):
|
||||
def _create_sub_graph(root: dict[str, Any]) -> str:
|
||||
rid = _create_node(root["info"])
|
||||
for child in root["children"]:
|
||||
cid = _create_sub_graph(child)
|
||||
@@ -218,7 +220,7 @@ class TraceCommands(BaseCommand):
|
||||
default=False,
|
||||
help="List all traces that contain error.",
|
||||
)
|
||||
def list(self, args):
|
||||
def list(self, args: argparse.Namespace) -> None:
|
||||
"""List all traces"""
|
||||
if not args.conn_str:
|
||||
raise exc.CommandError(
|
||||
@@ -229,13 +231,13 @@ class TraceCommands(BaseCommand):
|
||||
try:
|
||||
engine = base.get_driver(args.conn_str, **args.__dict__)
|
||||
except Exception as e:
|
||||
raise exc.CommandError(e.message)
|
||||
raise exc.CommandError(str(e))
|
||||
|
||||
fields = ("base_id", "timestamp")
|
||||
pretty_table = prettytable.PrettyTable(fields)
|
||||
pretty_table.align = "l"
|
||||
if not args.error_trace:
|
||||
traces = engine.list_traces(fields)
|
||||
traces = engine.list_traces(set(fields))
|
||||
else:
|
||||
traces = engine.list_error_traces()
|
||||
for trace in traces:
|
||||
|
||||
+11
-5
@@ -21,6 +21,7 @@ Command-line interface to the OpenStack Profiler.
|
||||
import argparse
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from oslo_config import cfg
|
||||
|
||||
@@ -31,13 +32,13 @@ from osprofiler import opts
|
||||
|
||||
|
||||
class OSProfilerShell:
|
||||
def __init__(self, argv):
|
||||
def __init__(self, argv: list[str]) -> None:
|
||||
args = self._get_base_parser().parse_args(argv)
|
||||
opts.set_defaults(cfg.CONF)
|
||||
|
||||
args.func(args)
|
||||
|
||||
def _get_base_parser(self):
|
||||
def _get_base_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="osprofiler", description=__doc__.strip(), add_help=True
|
||||
)
|
||||
@@ -50,9 +51,13 @@ class OSProfilerShell:
|
||||
|
||||
return parser
|
||||
|
||||
def _append_subcommands(self, parent_parser):
|
||||
def _append_subcommands(
|
||||
self, parent_parser: argparse.ArgumentParser
|
||||
) -> None:
|
||||
subcommands = parent_parser.add_subparsers(help="<subcommands>")
|
||||
for group_cls in commands.BaseCommand.__subclasses__():
|
||||
if group_cls.group_name is None:
|
||||
continue
|
||||
group_parser = subcommands.add_parser(group_cls.group_name)
|
||||
subcommand_parser = group_parser.add_subparsers()
|
||||
|
||||
@@ -71,7 +76,7 @@ class OSProfilerShell:
|
||||
command_parser.add_argument(*args, **kwargs)
|
||||
command_parser.set_defaults(func=callback)
|
||||
|
||||
def _no_project_and_domain_set(self, args):
|
||||
def _no_project_and_domain_set(self, args: Any) -> bool:
|
||||
if not (
|
||||
args.os_project_id
|
||||
or (
|
||||
@@ -85,7 +90,7 @@ class OSProfilerShell:
|
||||
return False
|
||||
|
||||
|
||||
def main(args=None):
|
||||
def main(args: list[str] | None = None) -> int | None:
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
|
||||
@@ -94,6 +99,7 @@ def main(args=None):
|
||||
except exc.CommandError as e:
|
||||
print(e.message)
|
||||
return 1
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+46
-40
@@ -15,6 +15,7 @@
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from osprofiler import _utils
|
||||
@@ -22,7 +23,7 @@ from osprofiler import _utils
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_driver(connection_string, *args, **kwargs):
|
||||
def get_driver(connection_string: str, *args: Any, **kwargs: Any) -> "Driver":
|
||||
"""Create driver's instance according to specified connection string"""
|
||||
# NOTE(ayelistratov) Backward compatibility with old Messaging notation
|
||||
# Remove after patching all OS services
|
||||
@@ -66,20 +67,25 @@ class Driver:
|
||||
and implemented by any class derived from this class.
|
||||
"""
|
||||
|
||||
default_trace_fields = {"base_id", "timestamp"}
|
||||
default_trace_fields: set[str] = {"base_id", "timestamp"}
|
||||
|
||||
def __init__(
|
||||
self, connection_str, project=None, service=None, host=None, **kwargs
|
||||
):
|
||||
self,
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.connection_str = connection_str
|
||||
self.project = project
|
||||
self.service = service
|
||||
self.host = host
|
||||
self.result = {}
|
||||
self.started_at = None
|
||||
self.finished_at = None
|
||||
self.result: dict[str, Any] = {}
|
||||
self.started_at: datetime.datetime | None = None
|
||||
self.finished_at: datetime.datetime | None = None
|
||||
# Last trace started time
|
||||
self.last_started_at = None
|
||||
self.last_started_at: datetime.datetime | None = None
|
||||
|
||||
profiler_config = kwargs.get("conf", {}).get("profiler", {})
|
||||
if hasattr(profiler_config, "filter_error_trace"):
|
||||
@@ -87,7 +93,7 @@ class Driver:
|
||||
else:
|
||||
self.filter_error_trace = False
|
||||
|
||||
def notify(self, info, **kwargs):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""This method will be called on each notifier.notify() call.
|
||||
|
||||
To add new drivers you should, create new subclass of this class and
|
||||
@@ -108,7 +114,7 @@ class Driver:
|
||||
"or has to be overridden"
|
||||
)
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
"""Forms and returns report composed from the stored notifications.
|
||||
|
||||
:param base_id: Base id of trace elements.
|
||||
@@ -119,11 +125,13 @@ class Driver:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
"""Returns backend specific name for the driver."""
|
||||
return cls.__name__
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query all traces from the storage.
|
||||
|
||||
:param fields: Set of trace fields to return. Defaults to 'base_id'
|
||||
@@ -136,7 +144,7 @@ class Driver:
|
||||
"or has to be overridden"
|
||||
)
|
||||
|
||||
def list_error_traces(self):
|
||||
def list_error_traces(self) -> list[dict[str, Any]]:
|
||||
"""Query all error traces from the storage.
|
||||
|
||||
:return List of traces, where each trace is a dictionary containing
|
||||
@@ -148,7 +156,7 @@ class Driver:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tree(nodes):
|
||||
def _build_tree(nodes: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Builds the tree (forest) data structure based on the list of nodes.
|
||||
|
||||
Tree building works in O(n*log(n)).
|
||||
@@ -161,7 +169,7 @@ class Driver:
|
||||
empty for leafs)
|
||||
"""
|
||||
|
||||
tree = []
|
||||
tree: list[dict[str, Any]] = []
|
||||
|
||||
for trace_id in nodes:
|
||||
node = nodes[trace_id]
|
||||
@@ -182,15 +190,15 @@ class Driver:
|
||||
|
||||
def _append_results(
|
||||
self,
|
||||
trace_id,
|
||||
parent_id,
|
||||
name,
|
||||
project,
|
||||
service,
|
||||
host,
|
||||
timestamp,
|
||||
raw_payload=None,
|
||||
):
|
||||
trace_id: str,
|
||||
parent_id: str,
|
||||
name: str,
|
||||
project: str | None,
|
||||
service: str | None,
|
||||
host: str | None,
|
||||
timestamp: str,
|
||||
raw_payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Appends the notification to the dictionary of notifications.
|
||||
|
||||
:param trace_id: UUID of current trace point
|
||||
@@ -204,9 +212,7 @@ class Driver:
|
||||
:param raw_payload: raw notification without any filtering, with all
|
||||
fields included
|
||||
"""
|
||||
timestamp = datetime.datetime.strptime(
|
||||
timestamp, "%Y-%m-%dT%H:%M:%S.%f"
|
||||
)
|
||||
ts = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
|
||||
if trace_id not in self.result:
|
||||
self.result[trace_id] = {
|
||||
"info": {
|
||||
@@ -222,29 +228,29 @@ class Driver:
|
||||
self.result[trace_id]["info"][f"meta.raw_payload.{name}"] = raw_payload
|
||||
|
||||
if name.endswith("stop"):
|
||||
self.result[trace_id]["info"]["finished"] = timestamp
|
||||
self.result[trace_id]["info"]["finished"] = ts
|
||||
self.result[trace_id]["info"]["exception"] = "None"
|
||||
if raw_payload and "info" in raw_payload:
|
||||
exc = raw_payload["info"].get("etype", "None")
|
||||
self.result[trace_id]["info"]["exception"] = exc
|
||||
else:
|
||||
self.result[trace_id]["info"]["started"] = timestamp
|
||||
if not self.last_started_at or self.last_started_at < timestamp:
|
||||
self.last_started_at = timestamp
|
||||
self.result[trace_id]["info"]["started"] = ts
|
||||
if not self.last_started_at or self.last_started_at < ts:
|
||||
self.last_started_at = ts
|
||||
|
||||
if not self.started_at or self.started_at > timestamp:
|
||||
self.started_at = timestamp
|
||||
if not self.started_at or self.started_at > ts:
|
||||
self.started_at = ts
|
||||
|
||||
if not self.finished_at or self.finished_at < timestamp:
|
||||
self.finished_at = timestamp
|
||||
if not self.finished_at or self.finished_at < ts:
|
||||
self.finished_at = ts
|
||||
|
||||
def _parse_results(self):
|
||||
def _parse_results(self) -> dict[str, Any]:
|
||||
"""Parses Driver's notifications placed by _append_results() .
|
||||
|
||||
:returns: full profiling report
|
||||
"""
|
||||
|
||||
def msec(dt):
|
||||
def msec(dt: datetime.timedelta) -> int:
|
||||
# NOTE(boris-42): Unfortunately this is the simplest way that works
|
||||
# in py26 and py27
|
||||
microsec = (
|
||||
@@ -252,7 +258,7 @@ class Driver:
|
||||
)
|
||||
return int(microsec / 1000.0)
|
||||
|
||||
stats = {}
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
for r in self.result.values():
|
||||
# NOTE(boris-42): We are not able to guarantee that the backend
|
||||
@@ -282,12 +288,12 @@ class Driver:
|
||||
"name": "total",
|
||||
"started": 0,
|
||||
"finished": msec(self.finished_at - self.started_at)
|
||||
if self.started_at
|
||||
if self.started_at and self.finished_at
|
||||
else None,
|
||||
"last_trace_started": msec(
|
||||
self.last_started_at - self.started_at
|
||||
)
|
||||
if self.started_at
|
||||
if self.started_at and self.last_started_at
|
||||
else None,
|
||||
},
|
||||
"children": self._build_tree(self.result),
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
from urllib import parse as parser
|
||||
|
||||
from oslo_config import cfg
|
||||
@@ -24,14 +25,14 @@ from osprofiler import exc
|
||||
class ElasticsearchDriver(base.Driver):
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
index_name="osprofiler-notifications",
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
conf=cfg.CONF,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
index_name: str = "osprofiler-notifications",
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
conf: cfg.ConfigOpts = cfg.CONF,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Elasticsearch driver for OSProfiler."""
|
||||
|
||||
super().__init__(
|
||||
@@ -60,10 +61,10 @@ class ElasticsearchDriver(base.Driver):
|
||||
self.index_name_error = "osprofiler-notifications-error"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def notify(self, info):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Send notifications to Elasticsearch.
|
||||
|
||||
:param info: Contains information about trace element.
|
||||
@@ -80,7 +81,7 @@ class ElasticsearchDriver(base.Driver):
|
||||
info = info.copy()
|
||||
info["project"] = self.project
|
||||
info["service"] = self.service
|
||||
self.client.index(
|
||||
self.client.index( # type: ignore[call-arg]
|
||||
index=self.index_name,
|
||||
doc_type=self.conf.profiler.es_doc_type,
|
||||
body=info,
|
||||
@@ -92,22 +93,22 @@ class ElasticsearchDriver(base.Driver):
|
||||
):
|
||||
self.notify_error_trace(info)
|
||||
|
||||
def notify_error_trace(self, info):
|
||||
def notify_error_trace(self, info: dict[str, Any]) -> None:
|
||||
"""Store base_id and timestamp of error trace to a separate index."""
|
||||
self.client.index(
|
||||
self.client.index( # type: ignore[call-arg]
|
||||
index=self.index_name_error,
|
||||
doc_type=self.conf.profiler.es_doc_type,
|
||||
body={"base_id": info["base_id"], "timestamp": info["timestamp"]},
|
||||
)
|
||||
|
||||
def _hits(self, response):
|
||||
def _hits(self, response: Any) -> list[Any]:
|
||||
"""Returns all hits of search query using scrolling
|
||||
|
||||
:param response: ElasticSearch query response
|
||||
"""
|
||||
scroll_id = response["_scroll_id"]
|
||||
scroll_size = len(response["hits"]["hits"])
|
||||
result = []
|
||||
result: list[Any] = []
|
||||
|
||||
while scroll_size > 0:
|
||||
for hit in response["hits"]["hits"]:
|
||||
@@ -120,7 +121,9 @@ class ElasticsearchDriver(base.Driver):
|
||||
|
||||
return result
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query all traces from the storage.
|
||||
|
||||
:param fields: Set of trace fields to return. Defaults to 'base_id'
|
||||
@@ -128,10 +131,10 @@ class ElasticsearchDriver(base.Driver):
|
||||
:returns: List of traces, where each trace is a dictionary containing
|
||||
at least `base_id` and `timestamp`.
|
||||
"""
|
||||
query = {"match_all": {}}
|
||||
query: dict[str, Any] = {"match_all": {}}
|
||||
fields = set(fields or self.default_trace_fields)
|
||||
|
||||
response = self.client.search(
|
||||
response = self.client.search( # type: ignore[call-arg]
|
||||
index=self.index_name,
|
||||
doc_type=self.conf.profiler.es_doc_type,
|
||||
size=self.conf.profiler.es_scroll_size,
|
||||
@@ -145,9 +148,9 @@ class ElasticsearchDriver(base.Driver):
|
||||
|
||||
return self._hits(response)
|
||||
|
||||
def list_error_traces(self):
|
||||
def list_error_traces(self) -> list[dict[str, Any]]:
|
||||
"""Returns all traces that have error/exception."""
|
||||
response = self.client.search(
|
||||
response = self.client.search( # type: ignore[call-arg]
|
||||
index=self.index_name_error,
|
||||
doc_type=self.conf.profiler.es_doc_type,
|
||||
size=self.conf.profiler.es_scroll_size,
|
||||
@@ -161,12 +164,12 @@ class ElasticsearchDriver(base.Driver):
|
||||
|
||||
return self._hits(response)
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
"""Retrieves and parses notification from Elasticsearch.
|
||||
|
||||
:param base_id: Base id of trace elements.
|
||||
"""
|
||||
response = self.client.search(
|
||||
response = self.client.search( # type: ignore[call-arg]
|
||||
index=self.index_name,
|
||||
doc_type=self.conf.profiler.es_doc_type,
|
||||
size=self.conf.profiler.es_scroll_size,
|
||||
|
||||
@@ -13,6 +13,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from oslo_config import cfg
|
||||
|
||||
from osprofiler.drivers import base
|
||||
from osprofiler import exc
|
||||
|
||||
@@ -21,17 +25,17 @@ from osprofiler import exc
|
||||
class Jaeger(base.Driver):
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
conf=None,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
conf: cfg.ConfigOpts | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Jaeger driver for OSProfiler."""
|
||||
|
||||
raise exc.CommandError('Jaeger driver is no longer supported')
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "jaeger"
|
||||
|
||||
@@ -19,6 +19,7 @@ Classes to use VMware vRealize Log Insight as the trace data store.
|
||||
|
||||
import json
|
||||
import logging as log
|
||||
from typing import Any
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import netaddr
|
||||
@@ -48,8 +49,13 @@ class LogInsightDriver(base.Driver):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, connection_str, project=None, service=None, host=None, **kwargs
|
||||
):
|
||||
self,
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
connection_str, project=project, service=service, host=host
|
||||
)
|
||||
@@ -73,19 +79,19 @@ class LogInsightDriver(base.Driver):
|
||||
self._client.login()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "loginsight"
|
||||
|
||||
def notify(self, info):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Send trace to Log Insight server."""
|
||||
|
||||
trace = info.copy()
|
||||
trace["project"] = self.project
|
||||
trace["service"] = self.service
|
||||
|
||||
event = {"text": "OSProfiler trace"}
|
||||
event: dict[str, Any] = {"text": "OSProfiler trace"}
|
||||
|
||||
def _create_field(name, content):
|
||||
def _create_field(name: str, content: Any) -> dict[str, Any]:
|
||||
return {"name": name, "content": content}
|
||||
|
||||
event["fields"] = [
|
||||
@@ -99,7 +105,7 @@ class LogInsightDriver(base.Driver):
|
||||
|
||||
self._client.send_event(event)
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
"""Retrieves and parses trace data from Log Insight.
|
||||
|
||||
:param base_id: Trace base ID
|
||||
@@ -150,13 +156,13 @@ class LogInsightClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host,
|
||||
username,
|
||||
password,
|
||||
api_port=9000,
|
||||
api_ssl_port=9543,
|
||||
query_timeout=60000,
|
||||
):
|
||||
host: str,
|
||||
username: str,
|
||||
password: str,
|
||||
api_port: int = 9000,
|
||||
api_ssl_port: int = 9543,
|
||||
query_timeout: int = 60000,
|
||||
) -> None:
|
||||
self._host = host
|
||||
self._username = username
|
||||
self._password = password
|
||||
@@ -164,9 +170,9 @@ class LogInsightClient:
|
||||
self._api_ssl_port = api_ssl_port
|
||||
self._query_timeout = query_timeout
|
||||
self._session = requests.Session()
|
||||
self._session_id = None
|
||||
self._session_id: str | None = None
|
||||
|
||||
def _build_base_url(self, scheme):
|
||||
def _build_base_url(self, scheme: str) -> str:
|
||||
proto_str = f"{scheme}://"
|
||||
host_str = (
|
||||
f"[{self._host}]" if netaddr.valid_ipv6(self._host) else self._host
|
||||
@@ -176,7 +182,7 @@ class LogInsightClient:
|
||||
)
|
||||
return proto_str + host_str + port_str
|
||||
|
||||
def _check_response(self, resp):
|
||||
def _check_response(self, resp: requests.Response) -> None:
|
||||
if resp.status_code == 440:
|
||||
raise exc.LogInsightLoginTimeout()
|
||||
|
||||
@@ -189,12 +195,18 @@ class LogInsightClient:
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
msg = resp.reason
|
||||
msg = resp.reason or msg
|
||||
raise exc.LogInsightAPIError(msg)
|
||||
|
||||
def _send_request(
|
||||
self, method, scheme, path, headers=None, body=None, params=None
|
||||
):
|
||||
self,
|
||||
method: str,
|
||||
scheme: str,
|
||||
path: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
url = f"{self._build_base_url(scheme)}/{path}"
|
||||
|
||||
headers = headers or {}
|
||||
@@ -205,20 +217,21 @@ class LogInsightClient:
|
||||
req = requests.Request(
|
||||
method, url, headers=headers, data=json.dumps(body), params=params
|
||||
)
|
||||
req = req.prepare()
|
||||
resp = self._session.send(req, verify=False)
|
||||
prepped = req.prepare()
|
||||
resp = self._session.send(prepped, verify=False)
|
||||
|
||||
self._check_response(resp)
|
||||
return resp.json()
|
||||
|
||||
def _get_auth_header(self):
|
||||
return {"X-LI-Session-Id": self._session_id}
|
||||
def _get_auth_header(self) -> dict[str, str]:
|
||||
return {"X-LI-Session-Id": self._session_id or ""}
|
||||
|
||||
def _trunc_session_id(self):
|
||||
def _trunc_session_id(self) -> str | None:
|
||||
if self._session_id:
|
||||
return self._session_id[-5:]
|
||||
return None
|
||||
|
||||
def _is_current_session_active(self):
|
||||
def _is_current_session_active(self) -> bool:
|
||||
try:
|
||||
self._send_request(
|
||||
"get",
|
||||
@@ -237,7 +250,7 @@ class LogInsightClient:
|
||||
return False
|
||||
|
||||
@synchronized("li_login_lock")
|
||||
def login(self):
|
||||
def login(self) -> None:
|
||||
# Another thread might have created the session while the current
|
||||
# thread was waiting for the lock.
|
||||
if self._session_id and self._is_current_session_active():
|
||||
@@ -254,13 +267,13 @@ class LogInsightClient:
|
||||
self._session_id = resp["sessionId"]
|
||||
LOG.debug("Established session %s.", self._trunc_session_id())
|
||||
|
||||
def send_event(self, event):
|
||||
def send_event(self, event: dict[str, Any]) -> None:
|
||||
events = {"events": [event]}
|
||||
self._send_request(
|
||||
"post", "http", self.EVENTS_INGEST_PATH, body=events
|
||||
)
|
||||
|
||||
def query_events(self, params):
|
||||
def query_events(self, params: dict[str, str]) -> Any:
|
||||
# Assumes that the keys and values in the params are strings and
|
||||
# the operator is "CONTAINS".
|
||||
constraints = []
|
||||
@@ -272,7 +285,7 @@ class LogInsightClient:
|
||||
self.QUERY_EVENTS_BASE_PATH, "/".join(constraints)
|
||||
)
|
||||
|
||||
def _query_events():
|
||||
def _query_events() -> Any:
|
||||
return self._send_request(
|
||||
"get",
|
||||
"https",
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
import functools
|
||||
import signal
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_utils import importutils
|
||||
|
||||
from osprofiler.drivers import base
|
||||
@@ -25,16 +27,16 @@ from osprofiler.drivers import base
|
||||
class Messaging(base.Driver):
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
context=None,
|
||||
conf=None,
|
||||
transport_url=None,
|
||||
idle_timeout=1,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
context: Any = None,
|
||||
conf: cfg.ConfigOpts | None = None,
|
||||
transport_url: str | None = None,
|
||||
idle_timeout: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Driver that uses messaging as transport for notifications
|
||||
|
||||
:param connection_str: OSProfiler driver connection string,
|
||||
@@ -73,7 +75,7 @@ class Messaging(base.Driver):
|
||||
)
|
||||
conf = oslo_config.cfg.CONF
|
||||
|
||||
transport_kwargs = {}
|
||||
transport_kwargs: dict[str, Any] = {}
|
||||
if transport_url:
|
||||
transport_kwargs["url"] = transport_url
|
||||
|
||||
@@ -91,10 +93,12 @@ class Messaging(base.Driver):
|
||||
self.idle_timeout = idle_timeout
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "messaging"
|
||||
|
||||
def notify(self, info, context=None):
|
||||
def notify(
|
||||
self, info: dict[str, Any], context: Any = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Send notifications to backend via oslo.messaging notifier API.
|
||||
|
||||
:param info: Contains information about trace element.
|
||||
@@ -118,7 +122,7 @@ class Messaging(base.Driver):
|
||||
info,
|
||||
)
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
notification_endpoint = NotifyEndpoint(self.oslo_messaging, base_id)
|
||||
endpoints = [notification_endpoint]
|
||||
targets = [self.oslo_messaging.Target(topic="profiler")]
|
||||
@@ -126,7 +130,7 @@ class Messaging(base.Driver):
|
||||
self.transport, targets, endpoints, executor="threading"
|
||||
)
|
||||
|
||||
state = dict(running=False)
|
||||
state: dict[str, bool] = dict(running=False)
|
||||
sfn = functools.partial(signal_handler, state=state)
|
||||
|
||||
# modify signal handlers to handle interruption gracefully
|
||||
@@ -194,21 +198,28 @@ class Messaging(base.Driver):
|
||||
|
||||
|
||||
class NotifyEndpoint:
|
||||
def __init__(self, oslo_messaging, base_id):
|
||||
self.received_messages = []
|
||||
def __init__(self, oslo_messaging: Any, base_id: str) -> None:
|
||||
self.received_messages: list[Any] = []
|
||||
self.last_read_time = time.time()
|
||||
self.filter_rule = oslo_messaging.NotificationFilter(
|
||||
payload={"base_id": base_id}
|
||||
)
|
||||
|
||||
def info(self, ctxt, publisher_id, event_type, payload, metadata):
|
||||
def info(
|
||||
self,
|
||||
ctxt: Any,
|
||||
publisher_id: Any,
|
||||
event_type: Any,
|
||||
payload: Any,
|
||||
metadata: Any,
|
||||
) -> None:
|
||||
self.received_messages.append(payload)
|
||||
self.last_read_time = time.time()
|
||||
|
||||
def get_messages(self):
|
||||
def get_messages(self) -> list[Any]:
|
||||
return self.received_messages
|
||||
|
||||
def get_last_read_time(self):
|
||||
def get_last_read_time(self) -> float:
|
||||
return self.last_read_time # time when the latest event was received
|
||||
|
||||
|
||||
@@ -216,6 +227,6 @@ class SignalExit(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
def signal_handler(signum, frame, state):
|
||||
def signal_handler(signum: Any, frame: Any, state: dict[str, bool]) -> None:
|
||||
state["running"] = False
|
||||
raise SignalExit()
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from osprofiler.drivers import base
|
||||
from osprofiler import exc
|
||||
|
||||
@@ -20,13 +22,13 @@ from osprofiler import exc
|
||||
class MongoDB(base.Driver):
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
db_name="osprofiler",
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
db_name: str = "osprofiler",
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""MongoDB driver for OSProfiler."""
|
||||
|
||||
super().__init__(
|
||||
@@ -45,14 +47,14 @@ class MongoDB(base.Driver):
|
||||
"To install with pip:\n `pip install pymongo`."
|
||||
)
|
||||
|
||||
client = MongoClient(self.connection_str, connect=False)
|
||||
client: Any = MongoClient(self.connection_str, connect=False)
|
||||
self.db = client[db_name]
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "mongodb"
|
||||
|
||||
def notify(self, info):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Send notifications to MongoDB.
|
||||
|
||||
:param info: Contains information about trace element.
|
||||
@@ -76,7 +78,7 @@ class MongoDB(base.Driver):
|
||||
):
|
||||
self.notify_error_trace(data)
|
||||
|
||||
def notify_error_trace(self, data):
|
||||
def notify_error_trace(self, data: dict[str, Any]) -> None:
|
||||
"""Store base_id and timestamp of error trace to a separate db."""
|
||||
self.db.profiler_error.update(
|
||||
{"base_id": data["base_id"]},
|
||||
@@ -84,7 +86,9 @@ class MongoDB(base.Driver):
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query all traces from the storage.
|
||||
|
||||
:param fields: Set of trace fields to return. Defaults to 'base_id'
|
||||
@@ -94,7 +98,7 @@ class MongoDB(base.Driver):
|
||||
"""
|
||||
fields = set(fields or self.default_trace_fields)
|
||||
ids = self.db.profiler.find({}).distinct("base_id")
|
||||
out_format = {"base_id": 1, "timestamp": 1, "_id": 0}
|
||||
out_format: dict[str, int] = {"base_id": 1, "timestamp": 1, "_id": 0}
|
||||
out_format.update({i: 1 for i in fields})
|
||||
return [
|
||||
self.db.profiler.find({"base_id": i}, out_format).sort(
|
||||
@@ -103,12 +107,12 @@ class MongoDB(base.Driver):
|
||||
for i in ids
|
||||
]
|
||||
|
||||
def list_error_traces(self):
|
||||
def list_error_traces(self) -> list[dict[str, Any]]:
|
||||
"""Returns all traces that have error/exception."""
|
||||
out_format = {"base_id": 1, "timestamp": 1, "_id": 0}
|
||||
return self.db.profiler_error.find({}, out_format)
|
||||
return list(self.db.profiler_error.find({}, out_format))
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
"""Retrieves and parses notification from MongoDB.
|
||||
|
||||
:param base_id: Base id of trace elements.
|
||||
|
||||
+31
-25
@@ -13,6 +13,7 @@
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
from typing import Any
|
||||
from urllib import parse as parser
|
||||
|
||||
from oslo_config import cfg
|
||||
@@ -26,13 +27,13 @@ from osprofiler import exc
|
||||
class OTLP(base.Driver):
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
conf=cfg.CONF,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
conf: cfg.ConfigOpts = cfg.CONF,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""OTLP driver using OTLP exporters."""
|
||||
|
||||
super().__init__(
|
||||
@@ -74,30 +75,32 @@ class OTLP(base.Driver):
|
||||
self.tracer = self.trace_api.get_tracer(__name__)
|
||||
|
||||
exporter = OTLPSpanExporter(f"{parsed_url.geturl()}/v1/traces")
|
||||
self.trace_api.get_tracer_provider().add_span_processor(
|
||||
self.trace_api.get_tracer_provider().add_span_processor( # type: ignore[attr-defined]
|
||||
BatchSpanProcessor(exporter)
|
||||
)
|
||||
|
||||
self.spans = collections.deque()
|
||||
self.spans: collections.deque[Any] = collections.deque()
|
||||
|
||||
def _get_service_name(self, conf, project, service):
|
||||
def _get_service_name(
|
||||
self, conf: cfg.ConfigOpts, project: str | None, service: str | None
|
||||
) -> str:
|
||||
prefix = conf.profiler_otlp.service_name_prefix
|
||||
if prefix:
|
||||
return f"{prefix}-{project}-{service}"
|
||||
return f"{project}-{service}"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "otlp"
|
||||
|
||||
def _kind(self, name):
|
||||
def _kind(self, name: str) -> Any:
|
||||
if "wsgi" in name:
|
||||
return self.trace_api.SpanKind.SERVER
|
||||
elif "db" in name or "http" in name or "api" in name:
|
||||
return self.trace_api.SpanKind.CLIENT
|
||||
return self.trace_api.SpanKind.INTERNAL
|
||||
|
||||
def _name(self, payload):
|
||||
def _name(self, payload: dict[str, Any]) -> str:
|
||||
info = payload["info"]
|
||||
if info.get("request"):
|
||||
return "WSGI_{}_{}".format(
|
||||
@@ -111,9 +114,10 @@ class OTLP(base.Driver):
|
||||
return "REQUESTS_{}_{}".format(
|
||||
info["requests"]["method"], info["requests"]["hostname"]
|
||||
)
|
||||
return payload["name"].rstrip("-start")
|
||||
return str(payload["name"]).rstrip("-start")
|
||||
|
||||
def notify(self, payload):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
payload = info
|
||||
if payload["name"].endswith("start"):
|
||||
parent = self.trace_api.SpanContext(
|
||||
trace_id=utils.uuid_to_int128(payload["base_id"]),
|
||||
@@ -136,12 +140,12 @@ class OTLP(base.Driver):
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
span._context = self.trace_api.SpanContext(
|
||||
trace_id=span.context.trace_id,
|
||||
span._context = self.trace_api.SpanContext( # type: ignore[attr-defined]
|
||||
trace_id=span.context.trace_id, # type: ignore[attr-defined]
|
||||
span_id=utils.shorten_id(payload["trace_id"]),
|
||||
is_remote=span.context.is_remote,
|
||||
trace_flags=span.context.trace_flags,
|
||||
trace_state=span.context.trace_state,
|
||||
is_remote=span.context.is_remote, # type: ignore[attr-defined]
|
||||
trace_flags=span.context.trace_flags, # type: ignore[attr-defined]
|
||||
trace_state=span.context.trace_state, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
self.spans.append(span)
|
||||
@@ -171,16 +175,18 @@ class OTLP(base.Driver):
|
||||
)
|
||||
span.end()
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
return self._parse_results()
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def list_error_traces(self):
|
||||
def list_error_traces(self) -> list[dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def create_span_tags(self, payload):
|
||||
def create_span_tags(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Create tags an OpenTracing compatible span.
|
||||
|
||||
:param info: Information from OSProfiler trace.
|
||||
@@ -188,7 +194,7 @@ class OTLP(base.Driver):
|
||||
from OpenTracing sematic conventions,
|
||||
and some other custom tags related to http, db calls.
|
||||
"""
|
||||
tags = {}
|
||||
tags: dict[str, Any] = {}
|
||||
info = payload["info"]
|
||||
|
||||
if info.get("db"):
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from urllib import parse as parser
|
||||
|
||||
from debtcollector import removals
|
||||
@@ -25,7 +27,7 @@ from osprofiler import exc
|
||||
|
||||
|
||||
class Redis(base.Driver):
|
||||
@removals.removed_kwarg(
|
||||
@removals.removed_kwarg( # type: ignore[untyped-decorator]
|
||||
"db",
|
||||
message="'db' parameter is deprecated "
|
||||
"and will be removed in future. "
|
||||
@@ -34,14 +36,14 @@ class Redis(base.Driver):
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
db=0,
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
conf=cfg.CONF,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
db: int = 0,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
conf: cfg.ConfigOpts = cfg.CONF,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Redis driver for OSProfiler."""
|
||||
|
||||
super().__init__(
|
||||
@@ -69,10 +71,10 @@ class Redis(base.Driver):
|
||||
self.namespace_error = "osprofiler_error:"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "redis"
|
||||
|
||||
def notify(self, info):
|
||||
def notify(self, info: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Send notifications to Redis.
|
||||
|
||||
:param info: Contains information about trace element.
|
||||
@@ -97,7 +99,7 @@ class Redis(base.Driver):
|
||||
):
|
||||
self.notify_error_trace(data)
|
||||
|
||||
def notify_error_trace(self, data):
|
||||
def notify_error_trace(self, data: dict[str, Any]) -> None:
|
||||
"""Store base_id and timestamp of error trace to a separate key."""
|
||||
key = self.namespace_error + data["base_id"]
|
||||
value = jsonutils.dumps(
|
||||
@@ -105,7 +107,9 @@ class Redis(base.Driver):
|
||||
)
|
||||
self.db.set(key, value)
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query all traces from the storage.
|
||||
|
||||
:param fields: Set of trace fields to return. Defaults to 'base_id'
|
||||
@@ -122,7 +126,10 @@ class Redis(base.Driver):
|
||||
ids = self.db.scan_iter(match=self.namespace_opt + "*")
|
||||
for i in ids:
|
||||
# for each trace query the first event to have a timestamp
|
||||
first_event = jsonutils.loads(self.db.lindex(i, 1))
|
||||
raw = cast(bytes | None, self.db.lindex(i, 1))
|
||||
if raw is None:
|
||||
continue
|
||||
first_event = jsonutils.loads(raw)
|
||||
result.append(
|
||||
{
|
||||
key: value
|
||||
@@ -132,15 +139,19 @@ class Redis(base.Driver):
|
||||
)
|
||||
return result
|
||||
|
||||
def _list_traces_legacy(self, fields):
|
||||
def _list_traces_legacy(self, fields: set[str]) -> list[dict[str, Any]]:
|
||||
# With current schema every event is stored under its own unique key
|
||||
# To query all traces we first need to get all keys, then
|
||||
# get all events, sort them and pick up only the first one
|
||||
ids = self.db.scan_iter(match=self.namespace + "*")
|
||||
traces = [jsonutils.loads(self.db.get(i)) for i in ids]
|
||||
traces = [
|
||||
jsonutils.loads(raw)
|
||||
for i in ids
|
||||
if (raw := cast(bytes | None, self.db.get(i))) is not None
|
||||
]
|
||||
traces.sort(key=lambda x: x["timestamp"])
|
||||
seen_ids = set()
|
||||
result = []
|
||||
seen_ids: set[str] = set()
|
||||
result: list[dict[str, Any]] = []
|
||||
for trace in traces:
|
||||
if trace["base_id"] not in seen_ids:
|
||||
seen_ids.add(trace["base_id"])
|
||||
@@ -153,13 +164,17 @@ class Redis(base.Driver):
|
||||
)
|
||||
return result
|
||||
|
||||
def list_error_traces(self):
|
||||
def list_error_traces(self) -> list[dict[str, Any]]:
|
||||
"""Returns all traces that have error/exception."""
|
||||
ids = self.db.scan_iter(match=self.namespace_error + "*")
|
||||
traces = [jsonutils.loads(self.db.get(i)) for i in ids]
|
||||
traces = [
|
||||
jsonutils.loads(raw)
|
||||
for i in ids
|
||||
if (raw := cast(bytes | None, self.db.get(i))) is not None
|
||||
]
|
||||
traces.sort(key=lambda x: x["timestamp"])
|
||||
seen_ids = set()
|
||||
result = []
|
||||
seen_ids: set[str] = set()
|
||||
result: list[dict[str, Any]] = []
|
||||
for trace in traces:
|
||||
if trace["base_id"] not in seen_ids:
|
||||
seen_ids.add(trace["base_id"])
|
||||
@@ -167,19 +182,24 @@ class Redis(base.Driver):
|
||||
|
||||
return result
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
"""Retrieves and parses notification from Redis.
|
||||
|
||||
:param base_id: Base id of trace elements.
|
||||
"""
|
||||
|
||||
def iterate_events():
|
||||
def iterate_events() -> Generator[bytes, None, None]:
|
||||
for key in self.db.scan_iter(
|
||||
match=self.namespace + base_id + "*"
|
||||
): # legacy
|
||||
yield self.db.get(key)
|
||||
data = cast(bytes | None, self.db.get(key))
|
||||
if data is not None:
|
||||
yield data
|
||||
|
||||
yield from self.db.lrange(self.namespace_opt + base_id, 0, -1)
|
||||
yield from cast(
|
||||
list[bytes],
|
||||
self.db.lrange(self.namespace_opt + base_id, 0, -1),
|
||||
)
|
||||
|
||||
for data in iterate_events():
|
||||
n = jsonutils.loads(data)
|
||||
@@ -199,7 +219,7 @@ class Redis(base.Driver):
|
||||
|
||||
|
||||
class RedisSentinel(Redis, base.Driver):
|
||||
@removals.removed_kwarg(
|
||||
@removals.removed_kwarg( # type: ignore[untyped-decorator]
|
||||
"db",
|
||||
message="'db' parameter is deprecated "
|
||||
"and will be removed in future. "
|
||||
@@ -208,14 +228,14 @@ class RedisSentinel(Redis, base.Driver):
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
connection_str,
|
||||
db=0,
|
||||
project=None,
|
||||
service=None,
|
||||
host=None,
|
||||
conf=cfg.CONF,
|
||||
**kwargs,
|
||||
):
|
||||
connection_str: str,
|
||||
db: int = 0,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
conf: cfg.ConfigOpts = cfg.CONF,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Redis driver for OSProfiler."""
|
||||
|
||||
super().__init__(
|
||||
@@ -238,16 +258,16 @@ class RedisSentinel(Redis, base.Driver):
|
||||
self.conf = conf
|
||||
socket_timeout = self.conf.profiler.socket_timeout
|
||||
parsed_url = parser.urlparse(self.connection_str)
|
||||
sentinel = Sentinel(
|
||||
[(parsed_url.hostname, int(parsed_url.port))],
|
||||
sentinel = Sentinel( # type: ignore[no-untyped-call]
|
||||
[(parsed_url.hostname, int(parsed_url.port))], # type: ignore[arg-type]
|
||||
password=parsed_url.password,
|
||||
socket_timeout=socket_timeout,
|
||||
)
|
||||
self.db = sentinel.master_for(
|
||||
self.db = sentinel.master_for( # type: ignore[no-untyped-call]
|
||||
self.conf.profiler.sentinel_service_name,
|
||||
socket_timeout=socket_timeout,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "redissentinel"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from oslo_serialization import jsonutils
|
||||
|
||||
@@ -25,14 +26,19 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
class SQLAlchemyDriver(base.Driver):
|
||||
def __init__(
|
||||
self, connection_str, project=None, service=None, host=None, **kwargs
|
||||
):
|
||||
self,
|
||||
connection_str: str,
|
||||
project: str | None = None,
|
||||
service: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
connection_str, project=project, service=service, host=host
|
||||
)
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine # type: ignore[import-not-found]
|
||||
from sqlalchemy import Table, MetaData, Column
|
||||
from sqlalchemy import String, JSON, Integer
|
||||
except ImportError:
|
||||
@@ -75,10 +81,12 @@ class SQLAlchemyDriver(base.Driver):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
def get_name(cls) -> str:
|
||||
return "sqlalchemy"
|
||||
|
||||
def notify(self, info, context=None):
|
||||
def notify(
|
||||
self, info: dict[str, Any], context: Any = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Write a notification the the database"""
|
||||
data = info.copy()
|
||||
base_id = data.pop("base_id", None)
|
||||
@@ -110,16 +118,19 @@ class SQLAlchemyDriver(base.Driver):
|
||||
base_id,
|
||||
)
|
||||
|
||||
def list_traces(self, fields=None):
|
||||
def list_traces(
|
||||
self, fields: set[str] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.sql import select # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
raise exc.CommandError(
|
||||
"To use this command, you should install 'SQLAlchemy'"
|
||||
)
|
||||
fields = set(fields or self.default_trace_fields)
|
||||
stmt = select([self._data_table])
|
||||
seen_ids = set()
|
||||
result = []
|
||||
seen_ids: set[str] = set()
|
||||
result: list[dict[str, Any]] = []
|
||||
traces = self._conn.execute(stmt).fetchall()
|
||||
for trace in traces:
|
||||
if trace["base_id"] not in seen_ids:
|
||||
@@ -133,7 +144,7 @@ class SQLAlchemyDriver(base.Driver):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_report(self, base_id):
|
||||
def get_report(self, base_id: str) -> dict[str, Any]:
|
||||
try:
|
||||
from sqlalchemy.sql import select
|
||||
except ImportError:
|
||||
|
||||
+3
-3
@@ -17,11 +17,11 @@
|
||||
class CommandError(Exception):
|
||||
"""Invalid usage of CLI."""
|
||||
|
||||
def __init__(self, message=None):
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message or self.__class__.__doc__
|
||||
def __str__(self) -> str:
|
||||
return self.message or self.__class__.__doc__ or ""
|
||||
|
||||
|
||||
class LogInsightAPIError(Exception):
|
||||
|
||||
@@ -27,8 +27,10 @@ Guidelines for writing new hacking checks
|
||||
import functools
|
||||
import re
|
||||
import tokenize
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any
|
||||
|
||||
from hacking import core
|
||||
from hacking import core # type: ignore[import-not-found]
|
||||
|
||||
re_no_construct_dict = re.compile(r"\sdict\(\)")
|
||||
re_no_construct_list = re.compile(r"\slist\(\)")
|
||||
@@ -47,11 +49,15 @@ re_str_format = re.compile(
|
||||
re_raises = re.compile(r"\s:raise[^s] *.*$|\s:raises *:.*$|\s:raises *[^:]+$")
|
||||
|
||||
|
||||
@core.flake8ext
|
||||
def skip_ignored_lines(func):
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def skip_ignored_lines(
|
||||
func: Callable[..., Any],
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(logical_line, filename):
|
||||
def wrapper(
|
||||
logical_line: str, filename: str
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
line = logical_line.strip()
|
||||
if not line or line.startswith("#") or line.endswith("# noqa"):
|
||||
return
|
||||
@@ -63,7 +69,9 @@ def skip_ignored_lines(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def _parse_assert_mock_str(line):
|
||||
def _parse_assert_mock_str(
|
||||
line: str,
|
||||
) -> tuple[int, str, str] | tuple[None, None, None]:
|
||||
point = line.find(".assert_")
|
||||
|
||||
if point != -1:
|
||||
@@ -73,9 +81,11 @@ def _parse_assert_mock_str(line):
|
||||
return None, None, None
|
||||
|
||||
|
||||
@skip_ignored_lines
|
||||
@core.flake8ext
|
||||
def check_assert_methods_from_mock(logical_line, filename):
|
||||
@skip_ignored_lines # type: ignore[untyped-decorator]
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_assert_methods_from_mock(
|
||||
logical_line: str, filename: str
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
"""Ensure that ``assert_*`` methods from ``mock`` library is used correctly
|
||||
|
||||
N301 - base error number
|
||||
@@ -135,9 +145,17 @@ def check_assert_methods_from_mock(logical_line, filename):
|
||||
)
|
||||
|
||||
|
||||
@skip_ignored_lines
|
||||
@core.flake8ext
|
||||
def check_quotes(logical_line, filename):
|
||||
def _check_triple(line: str, i: int, char: str) -> bool:
|
||||
return i + 2 < len(line) and (
|
||||
char == line[i] == line[i + 1] == line[i + 2]
|
||||
)
|
||||
|
||||
|
||||
@skip_ignored_lines # type: ignore[untyped-decorator]
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_quotes(
|
||||
logical_line: str, filename: str
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
"""Check that single quotation marks are not used
|
||||
|
||||
N350
|
||||
@@ -147,11 +165,6 @@ def check_quotes(logical_line, filename):
|
||||
in_multiline_string = False
|
||||
single_quotas_are_used = False
|
||||
|
||||
def check_tripple(line, i, char):
|
||||
return i + 2 < len(line) and (
|
||||
char == line[i] == line[i + 1] == line[i + 2]
|
||||
)
|
||||
|
||||
i = 0
|
||||
while i < len(logical_line):
|
||||
char = logical_line[i]
|
||||
@@ -163,7 +176,7 @@ def check_quotes(logical_line, filename):
|
||||
i += 1 # ignore next char
|
||||
|
||||
elif in_multiline_string:
|
||||
if check_tripple(logical_line, i, "\""):
|
||||
if _check_triple(logical_line, i, "\""):
|
||||
i += 2 # skip next 2 chars
|
||||
in_multiline_string = False
|
||||
|
||||
@@ -175,7 +188,7 @@ def check_quotes(logical_line, filename):
|
||||
break
|
||||
|
||||
elif char == "\"":
|
||||
if check_tripple(logical_line, i, "\""):
|
||||
if _check_triple(logical_line, i, "\""):
|
||||
in_multiline_string = True
|
||||
i += 3
|
||||
continue
|
||||
@@ -187,9 +200,11 @@ def check_quotes(logical_line, filename):
|
||||
yield (i, "N350 Remove Single quotes")
|
||||
|
||||
|
||||
@skip_ignored_lines
|
||||
@core.flake8ext
|
||||
def check_no_constructor_data_struct(logical_line, filename):
|
||||
@skip_ignored_lines # type: ignore[untyped-decorator]
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_no_constructor_data_struct(
|
||||
logical_line: str, filename: str
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
"""Check that data structs (lists, dicts) are declared using literals
|
||||
|
||||
N351
|
||||
@@ -203,8 +218,10 @@ def check_no_constructor_data_struct(logical_line, filename):
|
||||
yield (0, "N351 Remove list() construct and use literal []")
|
||||
|
||||
|
||||
@core.flake8ext
|
||||
def check_dict_formatting_in_string(logical_line, tokens):
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_dict_formatting_in_string(
|
||||
logical_line: str, tokens: Any
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
"""Check that strings do not use dict-formatting with a single replacement
|
||||
|
||||
N352
|
||||
@@ -275,9 +292,11 @@ def check_dict_formatting_in_string(logical_line, tokens):
|
||||
current_string = ""
|
||||
|
||||
|
||||
@skip_ignored_lines
|
||||
@core.flake8ext
|
||||
def check_using_unicode(logical_line, filename):
|
||||
@skip_ignored_lines # type: ignore[untyped-decorator]
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_using_unicode(
|
||||
logical_line: str, filename: str
|
||||
) -> Generator[tuple[int, str], None, None]:
|
||||
"""Check crosspython unicode usage
|
||||
|
||||
N353
|
||||
@@ -291,8 +310,8 @@ def check_using_unicode(logical_line, filename):
|
||||
)
|
||||
|
||||
|
||||
@core.flake8ext
|
||||
def check_raises(physical_line, filename):
|
||||
@core.flake8ext # type: ignore[untyped-decorator]
|
||||
def check_raises(physical_line: str, filename: str) -> tuple[int, str] | None:
|
||||
"""Check raises usage
|
||||
|
||||
N354
|
||||
@@ -309,3 +328,4 @@ def check_raises(physical_line, filename):
|
||||
"N354 ':Please use ':raises Exception: conditions' "
|
||||
"in docstrings.",
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -13,12 +13,23 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from oslo_config import cfg
|
||||
|
||||
from osprofiler import notifier
|
||||
from osprofiler import requests
|
||||
from osprofiler import web
|
||||
|
||||
|
||||
def init_from_conf(conf, context, project, service, host, **kwargs):
|
||||
def init_from_conf(
|
||||
conf: cfg.ConfigOpts,
|
||||
context: Any,
|
||||
project: str,
|
||||
service: str,
|
||||
host: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize notifier from service configuration
|
||||
|
||||
:param conf: service configuration
|
||||
|
||||
+14
-8
@@ -13,7 +13,9 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from osprofiler.drivers import base
|
||||
|
||||
@@ -21,16 +23,18 @@ from osprofiler.drivers import base
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _noop_notifier(info, context=None):
|
||||
def _noop_notifier(info: dict[str, Any], context: Any = None) -> None:
|
||||
"""Do nothing on notify()."""
|
||||
|
||||
|
||||
# NOTE(boris-42): By default we are using noop notifier.
|
||||
__notifier = _noop_notifier
|
||||
__notifier_cache = {} # map: connection-string -> notifier
|
||||
__notifier: Callable[..., None] = _noop_notifier
|
||||
__notifier_cache: dict[
|
||||
str, Callable[..., None]
|
||||
] = {} # map: connection-string -> notifier
|
||||
|
||||
|
||||
def notify(info):
|
||||
def notify(info: dict[str, Any]) -> None:
|
||||
"""Passes the profiling info to the notifier callable.
|
||||
|
||||
:param info: dictionary with profiling information
|
||||
@@ -38,12 +42,12 @@ def notify(info):
|
||||
__notifier(info)
|
||||
|
||||
|
||||
def get():
|
||||
def get() -> Callable[..., None]:
|
||||
"""Returns notifier callable."""
|
||||
return __notifier
|
||||
|
||||
|
||||
def set(notifier):
|
||||
def set(notifier: Callable[..., None]) -> None:
|
||||
"""Service that are going to use profiler should set callable notifier.
|
||||
|
||||
Callable notifier is instance of callable object, that accept exactly
|
||||
@@ -54,7 +58,9 @@ def set(notifier):
|
||||
__notifier = notifier
|
||||
|
||||
|
||||
def create(connection_string, *args, **kwargs):
|
||||
def create(
|
||||
connection_string: str, *args: Any, **kwargs: Any
|
||||
) -> Callable[..., None]:
|
||||
"""Create notifier based on specified plugin_name
|
||||
|
||||
:param connection_string: connection string which specifies the storage
|
||||
@@ -83,5 +89,5 @@ def create(connection_string, *args, **kwargs):
|
||||
return __notifier_cache[connection_string]
|
||||
|
||||
|
||||
def clear_notifier_cache():
|
||||
def clear_notifier_cache() -> None:
|
||||
__notifier_cache.clear()
|
||||
|
||||
+20
-20
@@ -247,17 +247,17 @@ cfg.CONF.register_opts(_OTLP_OPTS, group=_otlp_profiler_opt_group)
|
||||
|
||||
|
||||
def set_defaults(
|
||||
conf,
|
||||
enabled=None,
|
||||
trace_sqlalchemy=None,
|
||||
hmac_keys=None,
|
||||
connection_string=None,
|
||||
es_doc_type=None,
|
||||
es_scroll_time=None,
|
||||
es_scroll_size=None,
|
||||
socket_timeout=None,
|
||||
sentinel_service_name=None,
|
||||
):
|
||||
conf: cfg.ConfigOpts,
|
||||
enabled: bool | None = None,
|
||||
trace_sqlalchemy: bool | None = None,
|
||||
hmac_keys: str | None = None,
|
||||
connection_string: str | None = None,
|
||||
es_doc_type: str | None = None,
|
||||
es_scroll_time: str | None = None,
|
||||
es_scroll_size: int | None = None,
|
||||
socket_timeout: float | None = None,
|
||||
sentinel_service_name: str | None = None,
|
||||
) -> None:
|
||||
conf.register_opts(_PROFILER_OPTS, group=_profiler_opt_group)
|
||||
|
||||
if enabled is not None:
|
||||
@@ -308,35 +308,35 @@ def set_defaults(
|
||||
)
|
||||
|
||||
|
||||
def is_trace_enabled(conf=None):
|
||||
def is_trace_enabled(conf: cfg.ConfigOpts | None = None) -> bool:
|
||||
if conf is None:
|
||||
conf = cfg.CONF
|
||||
return conf.profiler.enabled
|
||||
return bool(conf.profiler.enabled)
|
||||
|
||||
|
||||
def is_db_trace_enabled(conf=None):
|
||||
def is_db_trace_enabled(conf: cfg.ConfigOpts | None = None) -> bool:
|
||||
if conf is None:
|
||||
conf = cfg.CONF
|
||||
return conf.profiler.enabled and conf.profiler.trace_sqlalchemy
|
||||
return bool(conf.profiler.enabled and conf.profiler.trace_sqlalchemy)
|
||||
|
||||
|
||||
def enable_web_trace(conf=None):
|
||||
def enable_web_trace(conf: cfg.ConfigOpts | None = None) -> None:
|
||||
if conf is None:
|
||||
conf = cfg.CONF
|
||||
if conf.profiler.enabled:
|
||||
web.enable(conf.profiler.hmac_keys)
|
||||
|
||||
|
||||
def disable_web_trace(conf=None):
|
||||
def disable_web_trace(conf: cfg.ConfigOpts | None = None) -> None:
|
||||
if conf is None:
|
||||
conf = cfg.CONF
|
||||
if conf.profiler.enabled:
|
||||
web.disable()
|
||||
|
||||
|
||||
def list_opts():
|
||||
def list_opts() -> list[tuple[str, list[cfg.Opt]]]:
|
||||
return [
|
||||
(_profiler_opt_group.name, _PROFILER_OPTS),
|
||||
(_jaegerprofiler_opt_group, _JAEGER_OPTS),
|
||||
(_otlp_profiler_opt_group, _OTLP_OPTS),
|
||||
(_jaegerprofiler_opt_group.name, _JAEGER_OPTS),
|
||||
(_otlp_profiler_opt_group.name, _OTLP_OPTS),
|
||||
]
|
||||
|
||||
+81
-49
@@ -14,10 +14,13 @@
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import inspect
|
||||
import socket
|
||||
import threading
|
||||
import types
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from oslo_utils import reflection
|
||||
from oslo_utils import timeutils
|
||||
@@ -27,15 +30,21 @@ from osprofiler import _utils as utils
|
||||
from osprofiler import notifier
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T", bound=type)
|
||||
|
||||
# NOTE(boris-42): Thread safe storage for profiler instances.
|
||||
__local_ctx = threading.local()
|
||||
|
||||
|
||||
def clean():
|
||||
def clean() -> None:
|
||||
__local_ctx.profiler = None
|
||||
|
||||
|
||||
def _ensure_no_multiple_traced(traceable_attrs):
|
||||
def _ensure_no_multiple_traced(
|
||||
traceable_attrs: list[tuple[str, Any]],
|
||||
) -> None:
|
||||
for attr_name, attr in traceable_attrs:
|
||||
traced_times = getattr(attr, "__traced__", 0)
|
||||
if traced_times:
|
||||
@@ -46,7 +55,11 @@ def _ensure_no_multiple_traced(traceable_attrs):
|
||||
)
|
||||
|
||||
|
||||
def init(hmac_key, base_id=None, parent_id=None):
|
||||
def init(
|
||||
hmac_key: str,
|
||||
base_id: str | None = None,
|
||||
parent_id: str | None = None,
|
||||
) -> "_Profiler":
|
||||
"""Init profiler instance for current thread.
|
||||
|
||||
You should call profiler.init() before using osprofiler.
|
||||
@@ -61,10 +74,10 @@ def init(hmac_key, base_id=None, parent_id=None):
|
||||
__local_ctx.profiler = _Profiler(
|
||||
hmac_key, base_id=base_id, parent_id=parent_id
|
||||
)
|
||||
return __local_ctx.profiler
|
||||
return cast("_Profiler", __local_ctx.profiler)
|
||||
|
||||
|
||||
def get():
|
||||
def get() -> "_Profiler | None":
|
||||
"""Get profiler instance.
|
||||
|
||||
:returns: Profiler instance or None if profiler wasn't inited.
|
||||
@@ -72,7 +85,7 @@ def get():
|
||||
return getattr(__local_ctx, "profiler", None)
|
||||
|
||||
|
||||
def start(name, info=None):
|
||||
def start(name: str, info: dict[str, Any] | None = None) -> None:
|
||||
"""Send new start notification if profiler instance is presented.
|
||||
|
||||
:param name: The name of action. E.g. wsgi, rpc, db, etc..
|
||||
@@ -84,7 +97,7 @@ def start(name, info=None):
|
||||
profiler.start(name, info=info)
|
||||
|
||||
|
||||
def stop(info=None):
|
||||
def stop(info: dict[str, Any] | None = None) -> None:
|
||||
"""Send new stop notification if profiler instance is presented."""
|
||||
profiler = get()
|
||||
if profiler:
|
||||
@@ -92,12 +105,12 @@ def stop(info=None):
|
||||
|
||||
|
||||
def trace(
|
||||
name,
|
||||
info=None,
|
||||
hide_args=False,
|
||||
hide_result=True,
|
||||
allow_multiple_trace=True,
|
||||
):
|
||||
name: str,
|
||||
info: dict[str, Any] | None = None,
|
||||
hide_args: bool = False,
|
||||
hide_result: bool = True,
|
||||
allow_multiple_trace: bool = True,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Trace decorator for functions.
|
||||
|
||||
Very useful if you would like to add trace point on existing function:
|
||||
@@ -126,7 +139,7 @@ def trace(
|
||||
info = info.copy()
|
||||
info["function"] = {}
|
||||
|
||||
def decorator(f):
|
||||
def decorator(f: Callable[P, R]) -> Callable[P, R]:
|
||||
trace_times = getattr(f, "__traced__", 0)
|
||||
if not allow_multiple_trace and trace_times:
|
||||
raise ValueError(
|
||||
@@ -134,19 +147,19 @@ def trace(
|
||||
)
|
||||
|
||||
try:
|
||||
f.__traced__ = trace_times + 1
|
||||
setattr(f, "__traced__", trace_times + 1)
|
||||
except AttributeError:
|
||||
# Tries to work around the following:
|
||||
#
|
||||
# AttributeError: 'instancemethod' object has no
|
||||
# attribute '__traced__'
|
||||
try:
|
||||
f.im_func.__traced__ = trace_times + 1
|
||||
setattr(getattr(f, "im_func"), "__traced__", trace_times + 1)
|
||||
except AttributeError: # nosec
|
||||
pass
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# NOTE(tovin07): Workaround for this issue
|
||||
# F823 local variable 'info'
|
||||
# (defined in enclosing scope on line xxx)
|
||||
@@ -161,7 +174,7 @@ def trace(
|
||||
info_["function"]["args"] = str(args)
|
||||
info_["function"]["kwargs"] = str(kwargs)
|
||||
|
||||
stop_info = None
|
||||
stop_info: dict[str, Any] | None = None
|
||||
try:
|
||||
start(name, info=info_)
|
||||
result = f(*args, **kwargs)
|
||||
@@ -184,15 +197,15 @@ def trace(
|
||||
|
||||
|
||||
def trace_cls(
|
||||
name,
|
||||
info=None,
|
||||
hide_args=False,
|
||||
hide_result=True,
|
||||
trace_private=False,
|
||||
allow_multiple_trace=True,
|
||||
trace_class_methods=False,
|
||||
trace_static_methods=False,
|
||||
):
|
||||
name: str,
|
||||
info: dict[str, Any] | None = None,
|
||||
hide_args: bool = False,
|
||||
hide_result: bool = True,
|
||||
trace_private: bool = False,
|
||||
allow_multiple_trace: bool = True,
|
||||
trace_class_methods: bool = False,
|
||||
trace_static_methods: bool = False,
|
||||
) -> Callable[[T], T]:
|
||||
"""Trace decorator for instances of class .
|
||||
|
||||
Very useful if you would like to add trace point on existing method:
|
||||
@@ -230,7 +243,9 @@ def trace_cls(
|
||||
tracing is not allowed (by default allow).
|
||||
"""
|
||||
|
||||
def trace_checker(attr_name, to_be_wrapped):
|
||||
def trace_checker(
|
||||
attr_name: str, to_be_wrapped: Any
|
||||
) -> tuple[bool, type | None]:
|
||||
if attr_name.startswith("__"):
|
||||
# Never trace really private methods.
|
||||
return (False, None)
|
||||
@@ -246,11 +261,11 @@ def trace_cls(
|
||||
return (True, classmethod)
|
||||
return (True, None)
|
||||
|
||||
def decorator(cls):
|
||||
def decorator(cls: T) -> T:
|
||||
clss = cls if inspect.isclass(cls) else cls.__class__
|
||||
mro_dicts = [c.__dict__ for c in inspect.getmro(clss)]
|
||||
traceable_attrs = []
|
||||
traceable_wrappers = []
|
||||
traceable_attrs: list[tuple[str, Any]] = []
|
||||
traceable_wrappers: list[type | None] = []
|
||||
for attr_name, attr in inspect.getmembers(cls):
|
||||
if not (inspect.ismethod(attr) or inspect.isfunction(attr)):
|
||||
continue
|
||||
@@ -305,7 +320,12 @@ class TracedMeta(type):
|
||||
traced - E.g. wsgi, rpc, db, etc...
|
||||
"""
|
||||
|
||||
def __init__(cls, cls_name, bases, attrs):
|
||||
def __init__(
|
||||
cls,
|
||||
cls_name: str,
|
||||
bases: tuple[type, ...],
|
||||
attrs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__(cls_name, bases, attrs)
|
||||
|
||||
trace_args = dict(getattr(cls, "__trace_args__", {}))
|
||||
@@ -318,7 +338,7 @@ class TracedMeta(type):
|
||||
"e.g. __trace_args__ = {'name': 'rpc'}"
|
||||
)
|
||||
|
||||
traceable_attrs = []
|
||||
traceable_attrs: list[tuple[str, Any]] = []
|
||||
for attr_name, attr_value in attrs.items():
|
||||
if not (
|
||||
inspect.ismethod(attr_value) or inspect.isfunction(attr_value)
|
||||
@@ -340,7 +360,7 @@ class TracedMeta(type):
|
||||
|
||||
|
||||
class Trace:
|
||||
def __init__(self, name, info=None):
|
||||
def __init__(self, name: str, info: dict[str, Any] | None = None) -> None:
|
||||
"""With statement way to use profiler start()/stop().
|
||||
|
||||
|
||||
@@ -358,12 +378,17 @@ class Trace:
|
||||
self._name = name
|
||||
self._info = info
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
start(self._name, info=self._info)
|
||||
|
||||
def __exit__(self, etype, value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
etype: type[BaseException] | None,
|
||||
value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> None:
|
||||
info = None
|
||||
if etype:
|
||||
if etype and value is not None:
|
||||
info = {
|
||||
"etype": reflection.get_class_name(etype),
|
||||
"message": value.args[0] if value.args else None,
|
||||
@@ -372,15 +397,22 @@ class Trace:
|
||||
|
||||
|
||||
class _Profiler:
|
||||
def __init__(self, hmac_key, base_id=None, parent_id=None):
|
||||
def __init__(
|
||||
self,
|
||||
hmac_key: str,
|
||||
base_id: str | None = None,
|
||||
parent_id: str | None = None,
|
||||
) -> None:
|
||||
self.hmac_key = hmac_key
|
||||
if not base_id:
|
||||
base_id = str(uuidutils.generate_uuid())
|
||||
self._trace_stack = collections.deque([base_id, parent_id or base_id])
|
||||
self._name = collections.deque()
|
||||
self._host = socket.gethostname()
|
||||
self._trace_stack: collections.deque[str] = collections.deque(
|
||||
[base_id, parent_id or base_id]
|
||||
)
|
||||
self._name: collections.deque[str] = collections.deque()
|
||||
self._host: str = socket.gethostname()
|
||||
|
||||
def get_shorten_id(self, uuid_id):
|
||||
def get_shorten_id(self, uuid_id: str | int) -> str:
|
||||
"""Return shorten id of a uuid that will be used in OpenTracing drivers
|
||||
|
||||
:param uuid_id: A string of uuid that was generated by uuidutils
|
||||
@@ -388,7 +420,7 @@ class _Profiler:
|
||||
"""
|
||||
return format(utils.shorten_id(uuid_id), "x")
|
||||
|
||||
def get_base_id(self):
|
||||
def get_base_id(self) -> str:
|
||||
"""Return base id of a trace.
|
||||
|
||||
Base id is the same for all elements in one trace. It's main goal is
|
||||
@@ -396,15 +428,15 @@ class _Profiler:
|
||||
"""
|
||||
return self._trace_stack[0]
|
||||
|
||||
def get_parent_id(self):
|
||||
def get_parent_id(self) -> str:
|
||||
"""Returns parent trace element id."""
|
||||
return self._trace_stack[-2]
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""Returns current trace element id."""
|
||||
return self._trace_stack[-1]
|
||||
|
||||
def start(self, name, info=None):
|
||||
def start(self, name: str, info: dict[str, Any] | None = None) -> None:
|
||||
"""Start new event.
|
||||
|
||||
Adds new trace_id to trace stack and sends notification
|
||||
@@ -424,7 +456,7 @@ class _Profiler:
|
||||
self._trace_stack.append(str(uuidutils.generate_uuid()))
|
||||
self._notify(f"{name}-start", info)
|
||||
|
||||
def stop(self, info=None):
|
||||
def stop(self, info: dict[str, Any] | None = None) -> None:
|
||||
"""Finish latest event.
|
||||
|
||||
Same as a start, but instead of pushing trace_id to stack it pops it.
|
||||
@@ -436,8 +468,8 @@ class _Profiler:
|
||||
self._notify(f"{self._name.pop()}-stop", info)
|
||||
self._trace_stack.pop()
|
||||
|
||||
def _notify(self, name, info):
|
||||
payload = {
|
||||
def _notify(self, name: str, info: dict[str, Any]) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"name": name,
|
||||
"base_id": self.get_base_id(),
|
||||
"trace_id": self.get_id(),
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging as log
|
||||
from typing import Any
|
||||
from urllib import parse as parser
|
||||
|
||||
from osprofiler import profiler
|
||||
@@ -24,7 +26,7 @@ from osprofiler import web
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
_FUNC = None
|
||||
_FUNC: Callable[..., Any] | None = None
|
||||
|
||||
try:
|
||||
from requests.adapters import HTTPAdapter
|
||||
@@ -32,11 +34,11 @@ except ImportError:
|
||||
pass
|
||||
else:
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
def send(self: Any, request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
parsed_url = parser.urlparse(request.url)
|
||||
|
||||
# Best effort guessing port if needed
|
||||
port = parsed_url.port or ""
|
||||
port: int | str = parsed_url.port or ""
|
||||
if not port and parsed_url.scheme == "http":
|
||||
port = 80
|
||||
elif not port and parsed_url.scheme == "https":
|
||||
@@ -60,6 +62,8 @@ else:
|
||||
# context/span.
|
||||
request.headers.update(web.get_trace_id_headers())
|
||||
|
||||
if _FUNC is None:
|
||||
raise RuntimeError("osprofiler requests adapter not initialized")
|
||||
response = _FUNC(self, request, *args, **kwargs)
|
||||
|
||||
profiler.stop(info={"requests": {"status_code": response.status_code}})
|
||||
@@ -69,9 +73,9 @@ else:
|
||||
_FUNC = HTTPAdapter.send
|
||||
|
||||
|
||||
def enable():
|
||||
def enable() -> None:
|
||||
if _FUNC:
|
||||
HTTPAdapter.send = send
|
||||
HTTPAdapter.send = send # type: ignore[method-assign]
|
||||
LOG.debug("profiling requests enabled")
|
||||
else:
|
||||
LOG.warning(
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Generator
|
||||
import contextlib
|
||||
import logging as log
|
||||
from typing import Any
|
||||
|
||||
from oslo_utils import reflection
|
||||
|
||||
@@ -25,20 +27,22 @@ LOG = log.getLogger(__name__)
|
||||
_DISABLED = False
|
||||
|
||||
|
||||
def disable():
|
||||
def disable() -> None:
|
||||
"""Disable tracing of all DB queries. Reduce a lot size of profiles."""
|
||||
global _DISABLED
|
||||
_DISABLED = True
|
||||
|
||||
|
||||
def enable():
|
||||
def enable() -> None:
|
||||
"""add_tracing adds event listeners for sqlalchemy."""
|
||||
|
||||
global _DISABLED
|
||||
_DISABLED = False
|
||||
|
||||
|
||||
def add_tracing(sqlalchemy, engine, name, hide_result=True):
|
||||
def add_tracing(
|
||||
sqlalchemy: Any, engine: Any, name: str, hide_result: bool = True
|
||||
) -> None:
|
||||
"""Add tracing to all sqlalchemy calls."""
|
||||
|
||||
if not _DISABLED:
|
||||
@@ -54,7 +58,7 @@ def add_tracing(sqlalchemy, engine, name, hide_result=True):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def wrap_session(sqlalchemy, sess):
|
||||
def wrap_session(sqlalchemy: Any, sess: Any) -> Generator[Any, None, None]:
|
||||
with sess as s:
|
||||
if not getattr(s.bind, "traced", False):
|
||||
add_tracing(sqlalchemy, s.bind, "db")
|
||||
@@ -62,17 +66,24 @@ def wrap_session(sqlalchemy, sess):
|
||||
yield s
|
||||
|
||||
|
||||
def _before_cursor_execute(name):
|
||||
def _before_cursor_execute(name: str) -> Any:
|
||||
"""Add listener that will send trace info before query is executed."""
|
||||
|
||||
def handler(conn, cursor, statement, params, context, executemany):
|
||||
def handler(
|
||||
conn: Any,
|
||||
cursor: Any,
|
||||
statement: Any,
|
||||
params: Any,
|
||||
context: Any,
|
||||
executemany: Any,
|
||||
) -> None:
|
||||
info = {"db": {"statement": statement, "params": params}}
|
||||
profiler.start(name, info=info)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _after_cursor_execute(hide_result=True):
|
||||
def _after_cursor_execute(hide_result: bool = True) -> Any:
|
||||
"""Add listener that will send trace info after query is executed.
|
||||
|
||||
:param hide_result: Boolean value to hide or show SQL result in trace.
|
||||
@@ -80,7 +91,14 @@ def _after_cursor_execute(hide_result=True):
|
||||
False - show SQL result in trace.
|
||||
"""
|
||||
|
||||
def handler(conn, cursor, statement, params, context, executemany):
|
||||
def handler(
|
||||
conn: Any,
|
||||
cursor: Any,
|
||||
statement: Any,
|
||||
params: Any,
|
||||
context: Any,
|
||||
executemany: Any,
|
||||
) -> None:
|
||||
if not hide_result:
|
||||
# Add SQL result to trace info in *-stop phase
|
||||
info = {"db": {"result": str(cursor._rows)}}
|
||||
@@ -91,7 +109,7 @@ def _after_cursor_execute(hide_result=True):
|
||||
return handler
|
||||
|
||||
|
||||
def handle_error(exception_context):
|
||||
def handle_error(exception_context: Any) -> None:
|
||||
"""Handle SQLAlchemy errors"""
|
||||
exception_class_name = reflection.get_class_name(
|
||||
exception_context.original_exception
|
||||
|
||||
@@ -86,7 +86,9 @@ class DriverTestCase(test.FunctionalTestCase):
|
||||
profiler.init("SECRET_KEY")
|
||||
|
||||
# grab base_id
|
||||
base_id = profiler.get().get_base_id()
|
||||
p = profiler.get()
|
||||
assert p is not None # noqa: S101
|
||||
base_id = p.get_base_id()
|
||||
|
||||
# execute profiled code
|
||||
foo = Foo()
|
||||
@@ -150,7 +152,9 @@ class RedisDriverTestCase(DriverTestCase):
|
||||
profiler.init("SECRET_KEY")
|
||||
|
||||
# grab base_id
|
||||
base_id = profiler.get().get_base_id()
|
||||
p = profiler.get()
|
||||
assert p is not None # noqa: S101
|
||||
base_id = p.get_base_id()
|
||||
|
||||
# execute profiled code
|
||||
foo = Foo()
|
||||
|
||||
@@ -17,6 +17,7 @@ import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import cast
|
||||
from unittest import mock
|
||||
|
||||
import ddt
|
||||
@@ -36,7 +37,8 @@ class ShellTestCase(test.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
os.environ = self.old_environment
|
||||
os.environ.clear()
|
||||
os.environ.update(self.old_environment)
|
||||
|
||||
def _trace_show_cmd(self, format_=None):
|
||||
cmd = f"trace show --connection-string redis:// {self.TRACE_ID}"
|
||||
@@ -47,7 +49,9 @@ class ShellTestCase(test.TestCase):
|
||||
def test_shell_main(self, mock_shell):
|
||||
mock_shell.side_effect = exc.CommandError("some_message")
|
||||
shell.main()
|
||||
self.assertEqual("some_message\n", sys.stdout.getvalue())
|
||||
self.assertEqual(
|
||||
"some_message\n", cast(io.StringIO, sys.stdout).getvalue()
|
||||
)
|
||||
|
||||
def run_command(self, cmd):
|
||||
shell.OSProfilerShell(cmd.split())
|
||||
@@ -117,7 +121,7 @@ class ShellTestCase(test.TestCase):
|
||||
separators=(",", ": "),
|
||||
)
|
||||
),
|
||||
sys.stdout.getvalue(),
|
||||
cast(io.StringIO, sys.stdout).getvalue(),
|
||||
)
|
||||
|
||||
@mock.patch("sys.stdout", io.StringIO())
|
||||
@@ -153,7 +157,7 @@ class ShellTestCase(test.TestCase):
|
||||
"\n".format(
|
||||
json.dumps(notifications, indent=4, separators=(",", ": "))
|
||||
),
|
||||
sys.stdout.getvalue(),
|
||||
cast(io.StringIO, sys.stdout).getvalue(),
|
||||
)
|
||||
|
||||
@mock.patch("sys.stdout", io.StringIO())
|
||||
|
||||
@@ -25,10 +25,10 @@ class NotifierBaseTestCase(test.TestCase):
|
||||
def get_name(cls):
|
||||
return "a"
|
||||
|
||||
def notify(self, a):
|
||||
def notify(self, a): # type: ignore[override]
|
||||
return a
|
||||
|
||||
self.assertEqual(10, base.get_driver("a://").notify(10))
|
||||
self.assertEqual(10, base.get_driver("a://").notify(10)) # type: ignore[arg-type, func-returns-value]
|
||||
|
||||
def test_factory_with_args(self):
|
||||
|
||||
@@ -41,10 +41,10 @@ class NotifierBaseTestCase(test.TestCase):
|
||||
def get_name(cls):
|
||||
return "b"
|
||||
|
||||
def notify(self, c):
|
||||
def notify(self, c): # type: ignore[override]
|
||||
return self.a + self.b + c
|
||||
|
||||
self.assertEqual(22, base.get_driver("b://", 5, b=7).notify(10))
|
||||
self.assertEqual(22, base.get_driver("b://", 5, b=7).notify(10)) # type: ignore[arg-type, func-returns-value]
|
||||
|
||||
def test_driver_not_found(self):
|
||||
self.assertRaises(
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
from oslo_serialization import jsonutils
|
||||
@@ -131,7 +132,7 @@ class RedisParserTestCase(test.TestCase):
|
||||
|
||||
def test_get_report(self):
|
||||
self.redisdb.db = mock.MagicMock()
|
||||
result_elements = [
|
||||
result_elements: list[dict[str, Any]] = [
|
||||
{
|
||||
"info": {
|
||||
"project": None,
|
||||
|
||||
@@ -27,7 +27,7 @@ class InitializerTestCase(testtools.TestCase):
|
||||
conf = mock.Mock()
|
||||
conf.profiler.connection_string = "driver://"
|
||||
conf.profiler.hmac_keys = "hmac_keys"
|
||||
context = {}
|
||||
context: dict[object, object] = {}
|
||||
project = "my-project"
|
||||
service = "my-service"
|
||||
host = "my-host"
|
||||
|
||||
@@ -39,7 +39,7 @@ class NotifierTestCase(test.TestCase):
|
||||
def test_notify(self):
|
||||
m = mock.MagicMock()
|
||||
notifier.set(m)
|
||||
notifier.notify(10)
|
||||
notifier.notify(10) # type: ignore[arg-type]
|
||||
|
||||
m.assert_called_once_with(10)
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ import collections
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
import unittest
|
||||
from typing import Any, ClassVar
|
||||
from unittest import mock
|
||||
|
||||
|
||||
@@ -43,8 +45,8 @@ class ProfilerGlobMethodsTestCase(test.TestCase):
|
||||
|
||||
def test_start(self):
|
||||
p = profiler.init("secret", base_id="1", parent_id="2")
|
||||
p.start = mock.MagicMock()
|
||||
profiler.start("name", info="info")
|
||||
p.start = mock.MagicMock() # type: ignore[method-assign]
|
||||
profiler.start("name", info="info") # type: ignore[arg-type]
|
||||
p.start.assert_called_once_with("name", info="info")
|
||||
|
||||
def test_stop_not_inited(self):
|
||||
@@ -53,8 +55,8 @@ class ProfilerGlobMethodsTestCase(test.TestCase):
|
||||
|
||||
def test_stop(self):
|
||||
p = profiler.init("secret", base_id="1", parent_id="2")
|
||||
p.stop = mock.MagicMock()
|
||||
profiler.stop(info="info")
|
||||
p.stop = mock.MagicMock() # type: ignore[method-assign]
|
||||
profiler.stop(info="info") # type: ignore[arg-type]
|
||||
p.stop.assert_called_once_with(info="info")
|
||||
|
||||
|
||||
@@ -159,10 +161,10 @@ class WithTraceTestCase(test.TestCase):
|
||||
@mock.patch("osprofiler.profiler.start")
|
||||
def test_with_trace(self, mock_start, mock_stop):
|
||||
|
||||
with profiler.Trace("a", info="a1"):
|
||||
with profiler.Trace("a", info="a1"): # type: ignore[arg-type]
|
||||
mock_start.assert_called_once_with("a", info="a1")
|
||||
mock_start.reset_mock()
|
||||
with profiler.Trace("b", info="b1"):
|
||||
with profiler.Trace("b", info="b1"): # type: ignore[arg-type]
|
||||
mock_start.assert_called_once_with("b", info="b1")
|
||||
mock_stop.assert_called_once_with(info=None)
|
||||
mock_stop.reset_mock()
|
||||
@@ -452,7 +454,7 @@ class TraceClsDecoratorTestCase(test.TestCase):
|
||||
|
||||
@mock.patch("osprofiler.profiler.stop")
|
||||
@mock.patch("osprofiler.profiler.start")
|
||||
@test.testcase.skip(
|
||||
@unittest.skip(
|
||||
"Static method tracing was disabled due the bug. This test should be "
|
||||
"skipped until we find the way to address it."
|
||||
)
|
||||
@@ -499,7 +501,10 @@ class TraceClsDecoratorTestCase(test.TestCase):
|
||||
|
||||
|
||||
class FakeTraceWithMetaclassBase(metaclass=profiler.TracedMeta):
|
||||
__trace_args__ = {"name": "rpc", "info": {"a": 10}}
|
||||
__trace_args__: ClassVar[dict[str, Any]] = {
|
||||
"name": "rpc",
|
||||
"info": {"a": 10},
|
||||
}
|
||||
|
||||
def method1(self, a, b, c=10):
|
||||
return a + b + c
|
||||
@@ -520,14 +525,21 @@ class FakeTraceDummy(FakeTraceWithMetaclassBase):
|
||||
|
||||
|
||||
class FakeTraceWithMetaclassHideArgs(FakeTraceWithMetaclassBase):
|
||||
__trace_args__ = {"name": "a", "info": {"b": 20}, "hide_args": True}
|
||||
__trace_args__: ClassVar[dict[str, Any]] = {
|
||||
"name": "a",
|
||||
"info": {"b": 20},
|
||||
"hide_args": True,
|
||||
}
|
||||
|
||||
def method5(self, k, l):
|
||||
return k + l
|
||||
|
||||
|
||||
class FakeTraceWithMetaclassPrivate(FakeTraceWithMetaclassBase):
|
||||
__trace_args__ = {"name": "rpc", "trace_private": True}
|
||||
__trace_args__: ClassVar[dict[str, Any]] = {
|
||||
"name": "rpc",
|
||||
"trace_private": True,
|
||||
}
|
||||
|
||||
def _new_private_method(self, m):
|
||||
return 2 * m
|
||||
|
||||
@@ -68,6 +68,7 @@ class UtilsTestCase(test.TestCase):
|
||||
|
||||
process_data = utils.signed_unpack(packed_data, hmac_data, [hmac])
|
||||
self.assertIn("hmac_key", process_data)
|
||||
assert process_data is not None # noqa: S101
|
||||
process_data.pop("hmac_key")
|
||||
self.assertEqual(data, process_data)
|
||||
|
||||
@@ -77,6 +78,7 @@ class UtilsTestCase(test.TestCase):
|
||||
packed_data, hmac_data = utils.signed_pack(data, keys[-1])
|
||||
|
||||
process_data = utils.signed_unpack(packed_data, hmac_data, keys)
|
||||
assert process_data is not None # noqa: S101
|
||||
self.assertEqual(keys[-1], process_data["hmac_key"])
|
||||
|
||||
def test_signed_pack_unpack_many_wrong_keys(self):
|
||||
|
||||
@@ -35,7 +35,7 @@ class WebTestCase(test.TestCase):
|
||||
self.addCleanup(profiler.clean)
|
||||
|
||||
def test_get_trace_id_headers_no_hmac(self):
|
||||
profiler.init(None, base_id="y", parent_id="z")
|
||||
profiler.init(None, base_id="y", parent_id="z") # type: ignore[arg-type]
|
||||
headers = web.get_trace_id_headers()
|
||||
self.assertEqual(headers, {})
|
||||
|
||||
@@ -50,6 +50,7 @@ class WebTestCase(test.TestCase):
|
||||
headers["X-Trace-Info"], headers["X-Trace-HMAC"], ["key"]
|
||||
)
|
||||
self.assertIn("hmac_key", trace_info)
|
||||
assert trace_info is not None # noqa: S101
|
||||
self.assertEqual("key", trace_info.pop("hmac_key"))
|
||||
self.assertEqual({"parent_id": "z", "base_id": "y"}, trace_info)
|
||||
|
||||
@@ -91,9 +92,9 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
request.get_response.return_value = "yeah!"
|
||||
request.headers = headers
|
||||
|
||||
middleware = web.WsgiMiddleware("app", hmac_key, enabled=enabled)
|
||||
middleware = web.WsgiMiddleware(mock.ANY, hmac_key, enabled=enabled)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
request.get_response.assert_called_once_with("app")
|
||||
request.get_response.assert_called_once_with(mock.ANY)
|
||||
self.assertEqual(0, mock_profiler_init.call_count)
|
||||
|
||||
@mock.patch("osprofiler.web.profiler.init")
|
||||
@@ -157,7 +158,8 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
def test_wsgi_middleware_invalid_trace_info(self, mock_profiler_init):
|
||||
hmac_key = "secret"
|
||||
pack = utils.signed_pack(
|
||||
[{"base_id": "1"}, {"parent_id": "2"}], hmac_key
|
||||
[{"base_id": "1"}, {"parent_id": "2"}], # type: ignore[arg-type]
|
||||
hmac_key,
|
||||
)
|
||||
headers = {
|
||||
"a": "1",
|
||||
@@ -191,7 +193,7 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
}
|
||||
|
||||
middleware = web.WsgiMiddleware(
|
||||
"app", f"secret1,{hmac_key}", enabled=True
|
||||
mock.ANY, f"secret1,{hmac_key}", enabled=True
|
||||
)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
mock_profiler_init.assert_called_once_with(
|
||||
@@ -220,7 +222,7 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
}
|
||||
|
||||
middleware = web.WsgiMiddleware(
|
||||
"app", f"{hmac_key},secret2", enabled=True
|
||||
mock.ANY, f"{hmac_key},secret2", enabled=True
|
||||
)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
mock_profiler_init.assert_called_once_with(
|
||||
@@ -249,7 +251,7 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
"X-Trace-HMAC": pack[1],
|
||||
}
|
||||
|
||||
middleware = web.WsgiMiddleware("app", hmac_key, enabled=True)
|
||||
middleware = web.WsgiMiddleware(mock.ANY, hmac_key, enabled=True)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
mock_profiler_init.assert_called_once_with(
|
||||
hmac_key=hmac_key, base_id="1", parent_id="2"
|
||||
@@ -269,7 +271,7 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
request = mock.MagicMock()
|
||||
request.get_response.return_value = "yeah!"
|
||||
web.disable()
|
||||
middleware = web.WsgiMiddleware("app", "hmac_key", enabled=True)
|
||||
middleware = web.WsgiMiddleware(mock.ANY, "hmac_key", enabled=True)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
self.assertEqual(mock_profiler_init.call_count, 0)
|
||||
|
||||
@@ -294,7 +296,7 @@ class WebMiddlewareTestCase(test.TestCase):
|
||||
}
|
||||
|
||||
web.enable("super_secret_key1,super_secret_key2")
|
||||
middleware = web.WsgiMiddleware("app", enabled=True)
|
||||
middleware = web.WsgiMiddleware(mock.ANY, enabled=True)
|
||||
self.assertEqual("yeah!", middleware(request))
|
||||
mock_profiler_init.assert_called_once_with(
|
||||
hmac_key=hmac_key, base_id="1", parent_id="2"
|
||||
|
||||
+35
-12
@@ -13,11 +13,18 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeGuard, TYPE_CHECKING
|
||||
|
||||
import webob.dec
|
||||
|
||||
from osprofiler import _utils as utils
|
||||
from osprofiler import profiler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed.wsgi import WSGIApplication
|
||||
|
||||
|
||||
# Trace keys that are required or optional, any other
|
||||
# keys that are present will cause the trace to be rejected...
|
||||
@@ -31,21 +38,21 @@ X_TRACE_INFO = "X-Trace-Info"
|
||||
X_TRACE_HMAC = "X-Trace-HMAC"
|
||||
|
||||
|
||||
def get_trace_id_headers():
|
||||
def get_trace_id_headers() -> dict[str, str]:
|
||||
"""Adds the trace id headers (and any hmac) into provided dictionary."""
|
||||
p = profiler.get()
|
||||
if p and p.hmac_key:
|
||||
data = {"base_id": p.get_base_id(), "parent_id": p.get_id()}
|
||||
pack = utils.signed_pack(data, p.hmac_key)
|
||||
return {X_TRACE_INFO: pack[0], X_TRACE_HMAC: pack[1]}
|
||||
return {X_TRACE_INFO: pack[0].decode(), X_TRACE_HMAC: pack[1] or ""}
|
||||
return {}
|
||||
|
||||
|
||||
_ENABLED = None
|
||||
_HMAC_KEYS = None
|
||||
_ENABLED: bool | None = None
|
||||
_HMAC_KEYS: list[str] | tuple[str, ...] | None = None
|
||||
|
||||
|
||||
def disable():
|
||||
def disable() -> None:
|
||||
"""Disable middleware.
|
||||
|
||||
This is the alternative way to disable middleware. It will be used to be
|
||||
@@ -55,7 +62,7 @@ def disable():
|
||||
_ENABLED = False
|
||||
|
||||
|
||||
def enable(hmac_keys=None):
|
||||
def enable(hmac_keys: str | None = None) -> None:
|
||||
"""Enable middleware."""
|
||||
global _ENABLED, _HMAC_KEYS
|
||||
_ENABLED = True
|
||||
@@ -65,7 +72,13 @@ def enable(hmac_keys=None):
|
||||
class WsgiMiddleware:
|
||||
"""WSGI Middleware that enables tracing for an application."""
|
||||
|
||||
def __init__(self, application, hmac_keys=None, enabled=False, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
application: WSGIApplication,
|
||||
hmac_keys: str | None = None,
|
||||
enabled: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize middleware with api-paste.ini arguments.
|
||||
|
||||
:application: wsgi app
|
||||
@@ -86,13 +99,17 @@ class WsgiMiddleware:
|
||||
self.hmac_keys = utils.split(hmac_keys or "")
|
||||
|
||||
@classmethod
|
||||
def factory(cls, global_conf, **local_conf):
|
||||
def filter_(app):
|
||||
def factory(
|
||||
cls, global_conf: dict[str, Any] | None, **local_conf: Any
|
||||
) -> Any:
|
||||
def filter_(app: Any) -> WsgiMiddleware:
|
||||
return cls(app, **local_conf)
|
||||
|
||||
return filter_
|
||||
|
||||
def _trace_is_valid(self, trace_info):
|
||||
def _trace_is_valid(
|
||||
self, trace_info: dict[str, Any] | None
|
||||
) -> TypeGuard[dict[str, Any]]:
|
||||
if not isinstance(trace_info, dict):
|
||||
return False
|
||||
trace_keys = set(trace_info.keys())
|
||||
@@ -103,7 +120,9 @@ class WsgiMiddleware:
|
||||
return True
|
||||
|
||||
@webob.dec.wsgify
|
||||
def __call__(self, request):
|
||||
def __call__(
|
||||
self, request: webob.request.Request
|
||||
) -> webob.response.Response:
|
||||
if (
|
||||
_ENABLED is not None
|
||||
and not _ENABLED
|
||||
@@ -121,7 +140,11 @@ class WsgiMiddleware:
|
||||
if not self._trace_is_valid(trace_info):
|
||||
return request.get_response(self.application)
|
||||
|
||||
profiler.init(**trace_info)
|
||||
profiler.init(
|
||||
hmac_key=trace_info["hmac_key"],
|
||||
base_id=trace_info.get("base_id"),
|
||||
parent_id=trace_info.get("parent_id"),
|
||||
)
|
||||
info = {
|
||||
"request": {
|
||||
"path": request.path,
|
||||
|
||||
@@ -65,6 +65,21 @@ packages = [
|
||||
"osprofiler"
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
show_column_numbers = true
|
||||
show_error_context = true
|
||||
strict = true
|
||||
disable_error_code = ["import-untyped"]
|
||||
exclude = "(?x)(doc | releasenotes)"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["osprofiler.tests.*"]
|
||||
disallow_untyped_calls = false
|
||||
disallow_untyped_defs = false
|
||||
disallow_subclassing_any = false
|
||||
disable_error_code = ["import-untyped"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
|
||||
|
||||
@@ -24,10 +24,31 @@ deps =
|
||||
oslo.messaging
|
||||
|
||||
[testenv:pep8]
|
||||
description =
|
||||
Run style checks.
|
||||
deps =
|
||||
pre-commit
|
||||
{[testenv:mypy]deps}
|
||||
commands =
|
||||
pre-commit run -a
|
||||
{[testenv:mypy]commands}
|
||||
|
||||
[testenv:mypy]
|
||||
description =
|
||||
Run type checks.
|
||||
deps =
|
||||
{[testenv]deps}
|
||||
mypy
|
||||
types-PySocks
|
||||
types-WebOb
|
||||
types-netaddr
|
||||
types-protobuf
|
||||
types-psutil
|
||||
types-python-dateutil
|
||||
types-requests
|
||||
types-setuptools
|
||||
commands =
|
||||
mypy --cache-dir="{envdir}/mypy_cache" {posargs:osprofiler}
|
||||
|
||||
[testenv:venv]
|
||||
commands = {posargs}
|
||||
|
||||
Reference in New Issue
Block a user