From c7e74f307c30b67daecf7cc2774c242968dd24b3 Mon Sep 17 00:00:00 2001 From: Stefano Rivera Date: Sun, 26 Apr 2026 15:31:27 -0400 Subject: [PATCH 1/2] Implement h_count in Histogram directly Rather than relying on MultiProcessCollector to accummulate it Signed-off-by: Stefano Rivera --- prometheus_client/metrics.py | 2 ++ prometheus_client/multiprocess.py | 2 -- tests/test_multiprocess.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/prometheus_client/metrics.py b/prometheus_client/metrics.py index 4c53b26b..7af3e948 100644 --- a/prometheus_client/metrics.py +++ b/prometheus_client/metrics.py @@ -652,6 +652,7 @@ def _metric_init(self) -> None: self._created = time.time() bucket_labelnames = self._labelnames + ('le',) self._sum = values.ValueClass(self._type, self._name, self._name + '_sum', self._labelnames, self._labelvalues, self._documentation) + self._count = values.ValueClass(self._type, self._name, self._name + '_count', self._labelnames, self._labelvalues, self._documentation) for b in self._upper_bounds: self._buckets.append(values.ValueClass( self._type, @@ -674,6 +675,7 @@ def observe(self, amount: float, exemplar: Optional[Dict[str, str]] = None) -> N """ self._raise_if_not_observable() self._sum.inc(amount) + self._count.inc(1) for i, bound in enumerate(self._upper_bounds): if amount <= bound: self._buckets[i].inc(1) diff --git a/prometheus_client/multiprocess.py b/prometheus_client/multiprocess.py index db55874e..38849ee2 100644 --- a/prometheus_client/multiprocess.py +++ b/prometheus_client/multiprocess.py @@ -156,8 +156,6 @@ def _accumulate_metrics(metrics, accumulate): samples[labels][sample_key] = acc else: samples[labels][sample_key] = value - if accumulate: - samples[labels][(metric.name + '_count', labels)] = acc # Convert to correct sample format. metric.samples = [] diff --git a/tests/test_multiprocess.py b/tests/test_multiprocess.py index ee0c7423..a04f0274 100644 --- a/tests/test_multiprocess.py +++ b/tests/test_multiprocess.py @@ -278,6 +278,7 @@ def add_label(key, value): expected_histogram = [ Sample('h_sum', labels, 6.0), + Sample('h_count', labels, 2.0), Sample('h_bucket', add_label('le', '0.005'), 0.0), Sample('h_bucket', add_label('le', '0.01'), 0.0), Sample('h_bucket', add_label('le', '0.025'), 0.0), @@ -293,7 +294,6 @@ def add_label(key, value): Sample('h_bucket', add_label('le', '7.5'), 2.0), Sample('h_bucket', add_label('le', '10.0'), 2.0), Sample('h_bucket', add_label('le', '+Inf'), 2.0), - Sample('h_count', labels, 2.0), ] self.assertEqual(metrics['h'].samples, expected_histogram) @@ -321,6 +321,7 @@ def add_label(key, value): expected_histogram = [ Sample('h_sum', {'view': 'view1'}, 6.0), + Sample('h_count', {'view': 'view1'}, 2.0), Sample('h_bucket', {'view': 'view1', 'le': '0.005'}, 0.0), Sample('h_bucket', {'view': 'view1', 'le': '0.01'}, 0.0), Sample('h_bucket', {'view': 'view1', 'le': '0.025'}, 0.0), @@ -336,8 +337,8 @@ def add_label(key, value): Sample('h_bucket', {'view': 'view1', 'le': '7.5'}, 2.0), Sample('h_bucket', {'view': 'view1', 'le': '10.0'}, 2.0), Sample('h_bucket', {'view': 'view1', 'le': '+Inf'}, 2.0), - Sample('h_count', {'view': 'view1'}, 2.0), Sample('h_sum', {'view': 'view2'}, 1.0), + Sample('h_count', {'view': 'view2'}, 1.0), Sample('h_bucket', {'view': 'view2', 'le': '0.005'}, 0.0), Sample('h_bucket', {'view': 'view2', 'le': '0.01'}, 0.0), Sample('h_bucket', {'view': 'view2', 'le': '0.025'}, 0.0), @@ -353,7 +354,6 @@ def add_label(key, value): Sample('h_bucket', {'view': 'view2', 'le': '7.5'}, 1.0), Sample('h_bucket', {'view': 'view2', 'le': '10.0'}, 1.0), Sample('h_bucket', {'view': 'view2', 'le': '+Inf'}, 1.0), - Sample('h_count', {'view': 'view2'}, 1.0), ] self.assertEqual(metrics['h'].samples, expected_histogram) @@ -435,6 +435,7 @@ def add_label(key, value): expected_histogram = [ Sample('h_sum', labels, 6.0), + Sample('h_count', labels, 2.0), Sample('h_bucket', add_label('le', '0.005'), 0.0), Sample('h_bucket', add_label('le', '0.01'), 0.0), Sample('h_bucket', add_label('le', '0.025'), 0.0), From b26db65054fb2c74241b56204aa3486f32636724 Mon Sep 17 00:00:00 2001 From: Stefano Rivera Date: Sun, 26 Apr 2026 16:58:51 -0400 Subject: [PATCH 2/2] Implement a Redis mode Signed-off-by: Stefano Rivera --- .github/workflows/ci.yaml | 3 + prometheus_client/metrics.py | 12 + prometheus_client/redis_collector.py | 107 +++++++ prometheus_client/values.py | 100 ++++++- pyproject.toml | 3 + tests/test_redis.py | 399 +++++++++++++++++++++++++++ tox.ini | 1 + 7 files changed, 621 insertions(+), 4 deletions(-) create mode 100644 prometheus_client/redis_collector.py create mode 100644 tests/test_redis.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a7e4e094..4f8a7fb0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,6 +59,9 @@ jobs: uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} + - name: Install Redis + run: | + apt-get -y install redis-server - name: Install dependencies run: | pip install --user tox "virtualenv<20.22.0" diff --git a/prometheus_client/metrics.py b/prometheus_client/metrics.py index 7af3e948..2f3c4845 100644 --- a/prometheus_client/metrics.py +++ b/prometheus_client/metrics.py @@ -207,6 +207,10 @@ def remove(self, *labelvalues: Any) -> None: warnings.warn( "Removal of labels has not been implemented in multi-process mode yet.", UserWarning) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Removal of labels has not been implemented in redis mode yet.", + UserWarning) if not self._labelnames: raise ValueError('No label names were set when constructing %s' % self) @@ -226,6 +230,10 @@ def remove_by_labels(self, labels: dict[str, str]) -> None: "Removal of labels has not been implemented in multi-process mode yet.", UserWarning ) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Removal of labels has not been implemented in redis mode yet.", + UserWarning) if not self._labelnames: raise ValueError('No label names were set when constructing %s' % self) @@ -258,6 +266,10 @@ def clear(self) -> None: warnings.warn( "Clearing labels has not been implemented in multi-process mode yet", UserWarning) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Clearing of labels has not been implemented in redis mode yet.", + UserWarning) with self._lock: self._metrics = {} diff --git a/prometheus_client/redis_collector.py b/prometheus_client/redis_collector.py new file mode 100644 index 00000000..9141a934 --- /dev/null +++ b/prometheus_client/redis_collector.py @@ -0,0 +1,107 @@ +from collections.abc import Iterable +import json +import os +from urllib.parse import urlsplit + +from .metrics_core import Metric +from .registry import Collector, CollectorRegistry +from .samples import Sample + + +def redis_client(): + """ + Create a redis client for PROMETHEUS_REDIS_URL. + + Configure the redis database via a URL in PROMETHEUS_REDIS_URL of the form + redis://localhost:6379/0 + """ + from redis import Redis + + parsed_url = urlsplit(os.environ["PROMETHEUS_REDIS_URL"]) + assert parsed_url.scheme == "redis" + assert parsed_url.path.startswith("/") + assert parsed_url.path[1:].isdigit() + port = parsed_url.port or 6379 + db = int(parsed_url.path[1:]) + return Redis(host=parsed_url.hostname, port=port, db=db) + + +class RedisCollector(Collector): + """Collector for redis mode.""" + + def __init__(self, registry: CollectorRegistry | None) -> None: + self._client = redis_client() + if registry: + registry.register(self) + + def _iter_values(self) -> Iterable[tuple[bytes, str]]: + cursor = 0 + while True: + cursor, keys = self._client.scan(cursor=cursor, match="value:*") + values = self._client.mget(keys) + yield from zip(keys, values) + if cursor == 0: + break + + def collect(self) -> Iterable[Metric]: + metrics: dict[str, Metric] = {} + histograms: set[str] = set() + + for key, value_s in self._iter_values(): + # FIXME: Catch ValueError here, just in case? + prefix_b, typ_b, mmap_key = key.split(b":", 2) + assert prefix_b == b"value" + typ = typ_b.decode() + value = float(value_s) + + metric_name, name, labels, help_text = json.loads(mmap_key) + + metric = metrics.get(metric_name) + if metric is None: + metric = Metric(metric_name, help_text, typ) + metrics[metric_name] = metric + if typ in ("histogram", "gaugehistogram"): + histograms.add(metric_name) + + metric.add_sample(name, labels, value) + + for name in histograms: + self._fix_histogram(metrics[name]) + + return metrics.values() + + def _fix_histogram(self, metric: Metric) -> None: + """ + Fix-up histogram samples. + + Sort the buckets as expected by a client, and accumulate the values. + The Histogram class is optimized to only increment the bucket that a + value first appears in, not larger ones that would also contain it. + """ + by_label: dict[tuple[tuple[str, ...], str], list[Sample]] = {} + + # Organize into lists of samples by label + for sample in metric.samples: + if "le" in sample.labels: + labels_without_le = sample.labels.copy() + labels_without_le.pop("le") + key = (tuple(labels_without_le.values()), sample.name) + else: + key = (tuple(sample.labels.values()), sample.name) + by_label.setdefault(key, []).append(sample) + + metric.samples = [] + + for (labels, name), samples in sorted(by_label.items()): + if name.endswith("_bucket"): + # Sort buckets within each label + samples.sort(key=lambda sample: float(sample.labels["le"])) + + # Accumulate values into larger buckets + value = 0.0 + for sample in samples: + value += sample.value + metric.samples.append(Sample(sample.name, sample.labels, value)) + + else: + metric.samples.extend(samples) diff --git a/prometheus_client/values.py b/prometheus_client/values.py index 6ff85e3b..54f576ce 100644 --- a/prometheus_client/values.py +++ b/prometheus_client/values.py @@ -1,11 +1,48 @@ import os from threading import Lock +from typing import Any, Protocol import warnings from .mmap_dict import mmap_key, MmapedDict +from .redis_collector import redis_client +from .samples import Exemplar -class MutexValue: +class Value(Protocol): + """Prometheus Client Metric implementation.""" + + _multiprocess: bool + + def __init__( + self, + typ: str, + metric_name: str, + name: str, + labelnames: list[str], + labelvalues: list[str], + help_text: str, + **kwargs: Any, + ) -> None: + """Initialize a metric.""" + + def inc(self, amount: float) -> None: + """Increment the metric by amount.""" + + def set(self, value: float, timestamp: float | None = None) -> None: + """Set the metric to value.""" + + def get(self) -> float: + """Get the current metric value.""" + + def set_exemplar(self, exemplar: Exemplar) -> None: + """Set an exemplar value.""" + exemplar # For vulture + + def get_exemplar(self) -> Exemplar | None: + """Get any set exemplar value.""" + + +class MutexValue(Value): """A float protected by a mutex.""" _multiprocess = False @@ -52,7 +89,7 @@ def MultiProcessValue(process_identifier=os.getpid): # This avoids the need to also have mutexes in __MmapDict. lock = Lock() - class MmapedValue: + class MmapedValue(Value): """A float protected by a mutex backed by a per-process mmaped file.""" _multiprocess = True @@ -125,12 +162,67 @@ def get_exemplar(self): return MmapedValue -def get_value_class(): +def RedisValue(): + """ + A value implementation that stores data in a redis/valkey database. + + Key scheme: + * value:typ:MMAP_KEY + """ + client = redis_client() + + class RedisValueImpl(Value): + """A float stored by redis.""" + + _multiprocess = False + + def __init__( + self, + typ: str, + metric_name: str, + name: str, + labelnames: list[str], + labelvalues: list[str], + help_text: str, + **kwargs: Any, + ) -> None: + key = mmap_key(metric_name, name, labelnames, labelvalues, help_text) + self._key = f"value:{typ}:{key}" + self._redis = client + self._redis.setnx(self._key, 0.0) + + def inc(self, amount: float) -> None: + self._redis.incrbyfloat(self._key, amount) + + def set(self, value: float, timestamp: float | None = None) -> None: + # TODO: Implement timestamps + self._redis.set(self._key, value) + + def get(self) -> float: + value = self._redis.get(self._key) + if value is None: + return 0.0 + return float(value) + + def set_exemplar(self, exemplar: Exemplar) -> None: + # TODO: Implement exemplars for redis. + return + + def get_exemplar(self) -> Exemplar | None: + # TODO: Implement exemplars for redis. + return None + + return RedisValueImpl + + +def get_value_class() -> type[Value]: # Should we enable multi-process mode? # This needs to be chosen before the first metric is constructed, # and as that may be in some arbitrary library the user/admin has # no control over we use an environment variable. - if 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: + if "PROMETHEUS_REDIS_URL" in os.environ: + return RedisValue() + elif 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: return MultiProcessValue() else: return MutexValue diff --git a/pyproject.toml b/pyproject.toml index 336cfb4f..915bf863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ aiohttp = [ django = [ "django", ] +redis = [ + "redis", +] [project.urls] Homepage = "https://github.com/prometheus/client_python" diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 00000000..b580861b --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,399 @@ +import os +import unittest +from collections.abc import Sequence +from typing import Any + +import pytest + +from prometheus_client import values +from prometheus_client.core import ( + CollectorRegistry, + Counter, + Gauge, + Histogram, + Sample, + Summary, +) +from prometheus_client.redis_collector import redis_client, RedisCollector +from prometheus_client.values import MutexValue, RedisValue, Value + +pytest.importorskip("redis") + +from redis import Redis + +class RedisTestCase(unittest.TestCase): + redis: Redis + + def setUp(self) -> None: + os.environ["PROMETHEUS_REDIS_URL"] = "redis://localhost/1" + client = redis_client() + if client.keys() != []: + raise pytest.skip( + "Redis database 1 has existing data. Refusing to clobber it." + ) + values.ValueClass = RedisValue() + + def tearDown(self) -> None: + redis_client().flushdb() + del os.environ["PROMETHEUS_REDIS_URL"] + values.ValueClass = MutexValue + + +class ValueTestCase(RedisTestCase): + def create_value( + self, + metric_name: str, + name: str | None = None, + type_: str = "counter", + labelnames: list[str] | None = None, + labelvalues: list[str] | None = None, + ) -> Value: + return values.ValueClass( + type_, + metric_name, + name or metric_name + "_total", + labelnames or [], + labelvalues or [], + "Help Text", + ) + + def test_initializes_value(self) -> None: + value = self.create_value("test1") + self.assertEqual(value.get(), 0.0) + + def test_sets_and_gets_value(self) -> None: + value = self.create_value("test2") + value.set(5) + self.assertEqual(value.get(), 5.0) + + def test_inc_value(self) -> None: + value = self.create_value("test3") + value.inc(3) + value.inc(5) + self.assertEqual(value.get(), 8.0) + + def test_differentiated_by_name(self) -> None: + v1 = self.create_value("value1") + v2 = self.create_value("value2") + v1.set(1) + v2.set(2) + self.assertEqual(v1.get(), 1.0) + self.assertEqual(v2.get(), 2.0) + + def test_differentiated_by_labels(self) -> None: + v1 = self.create_value("value3", labelnames=["a"], labelvalues=["1"]) + v2 = self.create_value("value3", labelnames=["a"], labelvalues=["2"]) + v1.set(1) + v2.set(2) + self.assertEqual(v1.get(), 1.0) + self.assertEqual(v2.get(), 2.0) + + +class TestRedis(RedisTestCase): + def setUp(self) -> None: + super().setUp() + self.registry = CollectorRegistry(support_collectors_without_names=True) + self.collector = RedisCollector(self.registry) + + def test_counter_adds(self) -> None: + c1 = Counter("c", "help", registry=None) + c2 = Counter("c", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("c_total")) + c1.inc(1) + c2.inc(2) + self.assertEqual(3, self.registry.get_sample_value("c_total")) + + def test_summary_adds(self) -> None: + s1 = Summary("s", "help", registry=None) + s2 = Summary("s", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("s_count")) + self.assertEqual(0, self.registry.get_sample_value("s_sum")) + s1.observe(1) + s2.observe(2) + self.assertEqual(2, self.registry.get_sample_value("s_count")) + self.assertEqual(3, self.registry.get_sample_value("s_sum")) + + def test_histogram_adds(self) -> None: + h1 = Histogram("h", "help", registry=None) + h2 = Histogram("h", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("h_count")) + self.assertEqual(0, self.registry.get_sample_value("h_sum")) + self.assertEqual(0, self.registry.get_sample_value("h_bucket", {"le": "5.0"})) + h1.observe(1) + h2.observe(2) + self.assertEqual(2, self.registry.get_sample_value("h_count")) + self.assertEqual(3, self.registry.get_sample_value("h_sum")) + self.assertEqual(2, self.registry.get_sample_value("h_bucket", {"le": "5.0"})) + + def test_namespace_subsystem(self) -> None: + c1 = Counter("c", "help", registry=None, namespace="ns", subsystem="ss") + c1.inc(1) + self.assertEqual(1, self.registry.get_sample_value("ns_ss_c_total")) + + def test_collect(self) -> None: + labels = {i: i for i in "abcd"} + + def add_label(key: str, value: str) -> dict[str, str]: + l = labels.copy() + l[key] = value + return l + + c = Counter("c", "help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "help", labelnames=labels.keys(), registry=None) + h = Histogram("h", "help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(1) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(5) + + metrics = {m.name: m for m in self.collector.collect()} + + self.assertEqual(metrics["c"].samples, [Sample("c_total", labels, 2.0)]) + + expected_histogram = [ + Sample("h_bucket", add_label("le", "0.005"), 0.0), + Sample("h_bucket", add_label("le", "0.01"), 0.0), + Sample("h_bucket", add_label("le", "0.025"), 0.0), + Sample("h_bucket", add_label("le", "0.05"), 0.0), + Sample("h_bucket", add_label("le", "0.075"), 0.0), + Sample("h_bucket", add_label("le", "0.1"), 0.0), + Sample("h_bucket", add_label("le", "0.25"), 0.0), + Sample("h_bucket", add_label("le", "0.5"), 0.0), + Sample("h_bucket", add_label("le", "0.75"), 0.0), + Sample("h_bucket", add_label("le", "1.0"), 1.0), + Sample("h_bucket", add_label("le", "2.5"), 1.0), + Sample("h_bucket", add_label("le", "5.0"), 2.0), + Sample("h_bucket", add_label("le", "7.5"), 2.0), + Sample("h_bucket", add_label("le", "10.0"), 2.0), + Sample("h_bucket", add_label("le", "+Inf"), 2.0), + Sample("h_count", labels, 2.0), + Sample("h_sum", labels, 6.0), + ] + + self.assertEqual(metrics["h"].samples, expected_histogram) + + def test_collect_histogram_ordering(self) -> None: + labels = {i: i for i in "abcd"} + + def add_label(key: str, value: str) -> dict[str, str]: + l = labels.copy() + l[key] = value + return l + + h = Histogram("h", "help", labelnames=["view"], registry=None) + + h.labels(view="view1").observe(1) + + h.labels(view="view1").observe(5) + h.labels(view="view2").observe(1) + + metrics = {m.name: m for m in self.collector.collect()} + + expected_histogram = [ + Sample("h_bucket", {"view": "view1", "le": "0.005"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.01"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.025"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.05"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.075"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.1"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.25"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.5"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.75"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "1.0"}, 1.0), + Sample("h_bucket", {"view": "view1", "le": "2.5"}, 1.0), + Sample("h_bucket", {"view": "view1", "le": "5.0"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "7.5"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "10.0"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "+Inf"}, 2.0), + Sample("h_count", {"view": "view1"}, 2.0), + Sample("h_sum", {"view": "view1"}, 6.0), + Sample("h_bucket", {"view": "view2", "le": "0.005"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.01"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.025"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.05"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.075"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.1"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.25"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.5"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.75"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "1.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "2.5"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "5.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "7.5"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "10.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "+Inf"}, 1.0), + Sample("h_count", {"view": "view2"}, 1.0), + Sample("h_sum", {"view": "view2"}, 1.0), + ] + + self.assertEqual(metrics["h"].samples, expected_histogram) + + def test_restrict(self) -> None: + labels = {i: i for i in "abcd"} + + def add_label(key: str, value: str) -> dict[str, str]: + l = labels.copy() + l[key] = value + return l + + c = Counter("c", "help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + + metrics = { + m.name: m for m in self.registry.restricted_registry(["c_total"]).collect() + } + + self.assertEqual(metrics.keys(), {"c"}) + + self.assertEqual(metrics["c"].samples, [Sample("c_total", labels, 2.0)]) + + def test_collect_preserves_help(self) -> None: + labels = {i: i for i in "abcd"} + + c = Counter("c", "c help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "g help", labelnames=labels.keys(), registry=None) + h = Histogram("h", "h help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(1) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(5) + + metrics = {m.name: m for m in self.collector.collect()} + + self.assertEqual(metrics["c"].documentation, "c help") + self.assertEqual(metrics["g"].documentation, "g help") + self.assertEqual(metrics["h"].documentation, "h help") + + def test_child_name_is_built_once_with_namespace_subsystem_unit(self) -> None: + """ + Repro for #1035: + In multiprocess mode, child metrics must NOT rebuild the full name + (namespace/subsystem/unit) a second time. The exported family name should + be built once, and Counter samples should use "_total". + """ + from prometheus_client import Counter + + class CustomCounter(Counter): + def __init__( + self, + name: str, + documentation: str, + labelnames: Sequence[str] = (), + namespace: str = "mydefaultnamespace", + subsystem: str = "mydefaultsubsystem", + unit: str = "", + registry: CollectorRegistry | None = None, + _labelvalues: Sequence[str] | None = None, + ): + # Intentionally provide non-empty defaults to trigger the bug path. + super().__init__( + name=name, + documentation=documentation, + labelnames=labelnames, + namespace=namespace, + subsystem=subsystem, + unit=unit, + registry=registry, + _labelvalues=_labelvalues, + ) + + # Create a Counter with explicit namespace/subsystem/unit + c = CustomCounter( + name="m", + documentation="help", + labelnames=("status", "method"), + namespace="ns", + subsystem="ss", + unit="seconds", # avoid '_total_total' confusion + registry=None, # not registered in local registry in multiprocess mode + ) + + # Create two labeled children + c.labels(status="200", method="GET").inc() + c.labels(status="404", method="POST").inc() + + # Collect from the multiprocess collector initialized in setUp() + metrics = {m.name: m for m in self.collector.collect()} + + # Family name should be built once (no '_total' in family name) + expected_family = "ns_ss_m_seconds" + self.assertIn(expected_family, metrics, f"missing family {expected_family}") + + # Counter samples must use '_total' + mf = metrics[expected_family] + sample_names = {s.name for s in mf.samples} + self.assertTrue( + all(name == expected_family + "_total" for name in sample_names), + f"unexpected sample names: {sample_names}", + ) + + # Ensure no double-built prefix sneaks in (the original bug) + bad_prefix = "mydefaultnamespace_mydefaultsubsystem_" + all_names = {mf.name, *sample_names} + self.assertTrue( + all(not n.startswith(bad_prefix) for n in all_names), + f"found double-built name(s): {[n for n in all_names if n.startswith(bad_prefix)]}", + ) + + def test_child_preserves_parent_context_for_subclasses(self) -> None: + """ + Ensure child metrics preserve parent's namespace/subsystem/unit information + so that subclasses can correctly use these parameters in their logic. + """ + + class ContextAwareCounter(Counter): + def __init__( + self, + name: str, + documentation: str, + labelnames: Sequence[str] = (), + namespace: str = "", + subsystem: str = "", + unit: str = "", + **kwargs: Any, + ): + self.context = { + "namespace": namespace, + "subsystem": subsystem, + "unit": unit, + } + super().__init__( + name, + documentation, + labelnames=labelnames, + namespace=namespace, + subsystem=subsystem, + unit=unit, + **kwargs, + ) + + parent = ContextAwareCounter( + "m", + "help", + labelnames=["status"], + namespace="prod", + subsystem="api", + unit="seconds", + registry=None, + ) + + child = parent.labels(status="200") + + # Verify that child retains parent's context + self.assertEqual(child.context["namespace"], "prod") + self.assertEqual(child.context["subsystem"], "api") + self.assertEqual(child.context["unit"], "seconds") diff --git a/tox.ini b/tox.ini index 992bd0a7..c9721de6 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ deps = pytest pytest-benchmark attrs + redis {py3.9,pypy3.9}: twisted {py3.9,pypy3.9}: aiohttp {py3.9,pypy3.9}: django