From 0c58afc7ee23fd5e5cca19318e3f055dcc79e021 Mon Sep 17 00:00:00 2001
From: Shreeya Deshpande <shreeyad@nvidia.com>
Date: Sun, 12 May 2024 15:38:07 -0700
Subject: [PATCH] Add get_statsd_client function

It takes a config dict and some caller-specific options, similar to
get_logger. Use this in get_logger, so our logging module doesn't
need to know anything about statsd config options.

Co-Authored-By: yanxiao@nvidia.com
Change-Id: I5ae2cc5c257fb8d7eab885977d9d9cf602224ec7
---
 swift/common/statsd_client.py          | 39 +++++++++++++++++-
 swift/common/utils/logs.py             | 17 +-------
 test/debug_logger.py                   |  7 +---
 test/unit/common/test_statsd_client.py | 57 ++++++++++++++++++++++++++
 test/unit/proxy/test_server.py         |  9 ++--
 5 files changed, 103 insertions(+), 26 deletions(-)

diff --git a/swift/common/statsd_client.py b/swift/common/statsd_client.py
index 492d264bbf..825bede383 100644
--- a/swift/common/statsd_client.py
+++ b/swift/common/statsd_client.py
@@ -24,21 +24,55 @@ from eventlet.green import socket
 import six
 
 
+def get_statsd_client(conf=None, tail_prefix='', logger=None):
+    """
+    Get an instance of StatsdClient using config settings.
+
+    **config and defaults**::
+
+        log_statsd_host = (disabled)
+        log_statsd_port = 8125
+        log_statsd_default_sample_rate = 1.0
+        log_statsd_sample_rate_factor = 1.0
+        log_statsd_metric_prefix = (empty-string)
+
+    :param conf: Configuration dict to read settings from
+    :param tail_prefix: tail prefix to pass to statsd client
+    :param logger: stdlib logger instance used by statsd client for logging
+    :return: an instance of ``StatsdClient``
+
+    """
+    conf = conf or {}
+
+    host = conf.get('log_statsd_host')
+    port = int(conf.get('log_statsd_port', 8125))
+    base_prefix = conf.get('log_statsd_metric_prefix', '')
+    default_sample_rate = float(
+        conf.get('log_statsd_default_sample_rate', 1))
+    sample_rate_factor = float(
+        conf.get('log_statsd_sample_rate_factor', 1))
+
+    return StatsdClient(host, port, base_prefix=base_prefix,
+                        tail_prefix=tail_prefix,
+                        default_sample_rate=default_sample_rate,
+                        sample_rate_factor=sample_rate_factor, logger=logger)
+
+
 class StatsdClient(object):
     def __init__(self, host, port, base_prefix='', tail_prefix='',
                  default_sample_rate=1, sample_rate_factor=1, logger=None):
         self._host = host
         self._port = port
         self._base_prefix = base_prefix
-        self._set_prefix(tail_prefix)
         self._default_sample_rate = default_sample_rate
         self._sample_rate_factor = sample_rate_factor
         self.random = random
         self.logger = logger
+        self._set_prefix(tail_prefix)
         self._sock_family = self._target = None
 
         if self._host:
-            self._set_sock_family_and_target(host, port)
+            self._set_sock_family_and_target(self._host, self._port)
 
     def _set_sock_family_and_target(self, host, port):
         # Determine if host is IPv4 or IPv6
@@ -123,6 +157,7 @@ class StatsdClient(object):
         if sample_rate is None:
             sample_rate = self._default_sample_rate
         sample_rate = sample_rate * self._sample_rate_factor
+
         parts = ['%s%s:%s' % (self._prefix, m_name, m_value), m_type]
         if sample_rate < 1:
             if self.random() < sample_rate:
diff --git a/swift/common/utils/logs.py b/swift/common/utils/logs.py
index 35c161fb6f..54780f10df 100644
--- a/swift/common/utils/logs.py
+++ b/swift/common/utils/logs.py
@@ -644,11 +644,6 @@ def get_logger(conf, name=None, log_to_console=False, log_route=None,
         log_udp_host = (disabled)
         log_udp_port = logging.handlers.SYSLOG_UDP_PORT
         log_address = /dev/log
-        log_statsd_host = (disabled)
-        log_statsd_port = 8125
-        log_statsd_default_sample_rate = 1.0
-        log_statsd_sample_rate_factor = 1.0
-        log_statsd_metric_prefix = (empty-string)
 
     :param conf: Configuration dict to read settings from
     :param name: This value is used to populate the ``server`` field in the log
@@ -734,18 +729,10 @@ def get_logger(conf, name=None, log_to_console=False, log_route=None,
         getattr(logging, conf.get('log_level', 'INFO').upper(), logging.INFO))
 
     # Setup logger with a StatsD client if so configured
-    statsd_host = conf.get('log_statsd_host')
-    statsd_port = int(conf.get('log_statsd_port', 8125))
-    base_prefix = conf.get('log_statsd_metric_prefix', '')
-    default_sample_rate = float(conf.get(
-        'log_statsd_default_sample_rate', 1))
-    sample_rate_factor = float(conf.get(
-        'log_statsd_sample_rate_factor', 1))
     if statsd_tail_prefix is None:
         statsd_tail_prefix = name
-    logger.statsd_client = statsd_client.StatsdClient(
-        statsd_host, statsd_port, base_prefix, statsd_tail_prefix,
-        default_sample_rate, sample_rate_factor, logger=logger)
+    logger.statsd_client = statsd_client.get_statsd_client(
+        conf, statsd_tail_prefix, logger)
 
     adapted_logger = LogAdapter(logger, name)
     other_handlers = conf.get('log_custom_handlers', None)
diff --git a/test/debug_logger.py b/test/debug_logger.py
index b59dfe9e70..ce17650898 100644
--- a/test/debug_logger.py
+++ b/test/debug_logger.py
@@ -30,11 +30,8 @@ class WARN_DEPRECATED(Exception):
 
 
 class FakeStatsdClient(statsd_client.StatsdClient):
-    def __init__(self, host, port, base_prefix='', tail_prefix='',
-                 default_sample_rate=1, sample_rate_factor=1, logger=None):
-        super(FakeStatsdClient, self).__init__(
-            host, port, base_prefix, tail_prefix, default_sample_rate,
-            sample_rate_factor, logger)
+    def __init__(self, *args, **kwargs):
+        super(FakeStatsdClient, self).__init__(*args, **kwargs)
         self.clear()
 
         # Capture then call parent pubic stat functions
diff --git a/test/unit/common/test_statsd_client.py b/test/unit/common/test_statsd_client.py
index 12329f2118..73ea86d41f 100644
--- a/test/unit/common/test_statsd_client.py
+++ b/test/unit/common/test_statsd_client.py
@@ -82,6 +82,9 @@ class TestStatsdClient(BaseTestStasdClient):
     def test_init_host(self):
         client = StatsdClient('myhost', 1234)
         self.assertEqual([('myhost', 1234)], self.getaddrinfo_calls)
+        client1 = statsd_client.get_statsd_client(
+            conf={'log_statsd_host': 'myhost1',
+                  'log_statsd_port': 1235})
         with mock.patch.object(client, '_open_socket') as mock_open:
             self.assertIs(client.increment('tunafish'),
                           mock_open.return_value.sendto.return_value)
@@ -90,14 +93,28 @@ class TestStatsdClient(BaseTestStasdClient):
             mock.call().sendto(b'tunafish:1|c', ('myhost', 1234)),
             mock.call().close(),
         ])
+        with mock.patch.object(client1, '_open_socket') as mock_open1:
+            self.assertIs(client1.increment('tunafish'),
+                          mock_open1.return_value.sendto.return_value)
+        self.assertEqual(mock_open1.mock_calls, [
+            mock.call(),
+            mock.call().sendto(b'tunafish:1|c', ('myhost1', 1235)),
+            mock.call().close(),
+        ])
 
     def test_init_host_is_none(self):
         client = StatsdClient(None, None)
+        client1 = statsd_client.get_statsd_client(conf=None,
+                                                  logger=None)
         self.assertIsNone(client._host)
+        self.assertIsNone(client1._host)
         self.assertFalse(self.getaddrinfo_calls)
         with mock.patch.object(client, '_open_socket') as mock_open:
             self.assertIsNone(client.increment('tunafish'))
         self.assertFalse(mock_open.mock_calls)
+        with mock.patch.object(client1, '_open_socket') as mock_open1:
+            self.assertIsNone(client1.increment('tunafish'))
+        self.assertFalse(mock_open1.mock_calls)
         self.assertFalse(self.getaddrinfo_calls)
 
 
@@ -762,3 +779,43 @@ class TestStatsdLoggingDelegation(unittest.TestCase):
         self.assertStat('alpha.beta.another.counter:3|c|@0.9912',
                         self.logger.update_stats, 'another.counter', 3,
                         sample_rate=0.9912)
+
+
+class TestModuleFunctions(unittest.TestCase):
+    def setUp(self):
+        self.logger = debug_logger()
+
+    def test_get_statsd_client_defaults(self):
+        # no options configured
+        client = statsd_client.get_statsd_client({})
+        self.assertIsInstance(client, StatsdClient)
+        self.assertIsNone(client._host)
+        self.assertEqual(8125, client._port)
+        self.assertEqual('', client._base_prefix)
+        self.assertEqual('', client._prefix)
+        self.assertEqual(1.0, client._default_sample_rate)
+        self.assertEqual(1.0, client._sample_rate_factor)
+        self.assertIsNone(client.logger)
+
+    def test_get_statsd_client_options(self):
+        # legacy options...
+        conf = {
+            'log_statsd_host': 'example.com',
+            'log_statsd_port': '6789',
+            'log_statsd_metric_prefix': 'banana',
+            'log_statsd_default_sample_rate': '3.3',
+            'log_statsd_sample_rate_factor': '4.4',
+            'log_junk': 'ignored',
+        }
+        client = statsd_client.get_statsd_client(
+            conf, tail_prefix='milkshake', logger=self.logger)
+        self.assertIsInstance(client, StatsdClient)
+        self.assertEqual('example.com', client._host)
+        self.assertEqual(6789, client._port)
+        self.assertEqual('banana', client._base_prefix)
+        self.assertEqual('banana.milkshake.', client._prefix)
+        self.assertEqual(3.3, client._default_sample_rate)
+        self.assertEqual(4.4, client._sample_rate_factor)
+        self.assertEqual(self.logger, client.logger)
+        warn_lines = self.logger.get_lines_for_level('warning')
+        self.assertEqual([], warn_lines)
diff --git a/test/unit/proxy/test_server.py b/test/unit/proxy/test_server.py
index 1ecacace02..46cc3f4be3 100644
--- a/test/unit/proxy/test_server.py
+++ b/test/unit/proxy/test_server.py
@@ -2292,15 +2292,15 @@ class TestProxyServerConfigLoading(unittest.TestCase):
         use = egg:swift#proxy
         """ % self.tempdir
         conf_path = self._write_conf(dedent(conf_sections))
-
-        with mock.patch('swift.common.statsd_client.StatsdClient')\
+        with mock.patch('swift.common.statsd_client.StatsdClient') \
                 as mock_statsd:
             app = loadapp(conf_path, allow_modify_pipeline=False)
         # logger name is hard-wired 'proxy-server'
         self.assertEqual('proxy-server', app.logger.name)
         self.assertEqual('swift', app.logger.server)
         mock_statsd.assert_called_once_with(
-            'example.com', 8125, '', 'proxy-server', 1.0, 1.0,
+            'example.com', 8125, base_prefix='', tail_prefix='proxy-server',
+            default_sample_rate=1.0, sample_rate_factor=1.0,
             logger=app.logger.logger)
 
         conf_sections = """
@@ -2326,7 +2326,8 @@ class TestProxyServerConfigLoading(unittest.TestCase):
         self.assertEqual('test-name', app.logger.server)
         # statsd tail prefix is hard-wired 'proxy-server'
         mock_statsd.assert_called_once_with(
-            'example.com', 8125, '', 'proxy-server', 1.0, 1.0,
+            'example.com', 8125, base_prefix='', tail_prefix='proxy-server',
+            default_sample_rate=1.0, sample_rate_factor=1.0,
             logger=app.logger.logger)