diff --git a/src/promptfoo/telemetry.py b/src/promptfoo/telemetry.py index f1d8077..731663d 100644 --- a/src/promptfoo/telemetry.py +++ b/src/promptfoo/telemetry.py @@ -9,6 +9,7 @@ import os import platform import sys +import threading import uuid from pathlib import Path from typing import Any @@ -82,9 +83,10 @@ def _write_global_config(config: dict[str, Any]) -> None: pass # Silently fail - telemetry should never break the CLI -def _get_user_id() -> str: +def _get_user_id(config: dict[str, Any] | None = None) -> str: """Get or create a unique user ID stored in the global config.""" - config = _read_global_config() + if config is None: + config = _read_global_config() user_id = config.get("id") if not user_id: @@ -95,9 +97,10 @@ def _get_user_id() -> str: return user_id -def _get_user_email() -> str | None: +def _get_user_email(config: dict[str, Any] | None = None) -> str | None: """Get the user email from the global config if set.""" - config = _read_global_config() + if config is None: + config = _read_global_config() account = config.get("account", {}) return account.get("email") if isinstance(account, dict) else None @@ -127,8 +130,9 @@ def _ensure_initialized(self) -> None: return try: - self._user_id = _get_user_id() - self._email = _get_user_email() + config = _read_global_config() + self._user_id = _get_user_id(config) + self._email = _get_user_email(config) self._client = Posthog( project_api_key=_POSTHOG_KEY, host=_POSTHOG_HOST, @@ -182,15 +186,20 @@ def shutdown(self) -> None: # Global singleton instance _telemetry: _Telemetry | None = None +_telemetry_lock = threading.Lock() def _get_telemetry() -> _Telemetry: """Get the global telemetry instance.""" global _telemetry - if _telemetry is None: - _telemetry = _Telemetry() - atexit.register(_telemetry.shutdown) - return _telemetry + if _telemetry is not None: + return _telemetry + + with _telemetry_lock: + if _telemetry is None: + _telemetry = _Telemetry() + atexit.register(_telemetry.shutdown) + return _telemetry def record_wrapper_used(method: str) -> None: diff --git a/tests/test_environment.py b/tests/test_environment.py index 5b2b731..e99ed7d 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -31,6 +31,13 @@ def test_read_probe_file_returns_none_when_missing(self, tmp_path: Path) -> None """Missing probe files return None.""" assert _read_probe_file(tmp_path / "missing") is None + def test_read_probe_file_returns_content_when_readable(self, tmp_path: Path) -> None: + """Readable probe files return their text content.""" + probe_file = tmp_path / "probe" + probe_file.write_text("value") + + assert _read_probe_file(probe_file) == "value" + def test_read_probe_file_returns_none_when_unreadable(self, tmp_path: Path) -> None: """Unreadable probe files return None instead of raising.""" probe_file = tmp_path / "probe" @@ -245,6 +252,7 @@ def test_detect_kubernetes_from_env(self, monkeypatch: pytest.MonkeyPatch) -> No mock_path.return_value.exists.return_value = False is_docker, is_k8s = _detect_container() + assert is_docker is False assert is_k8s is True def test_detect_container_returns_tuple(self) -> None: diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index 591d414..2deef3f 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -9,6 +9,7 @@ """ import os +import threading from pathlib import Path from unittest import mock @@ -17,11 +18,13 @@ from promptfoo.telemetry import ( _get_config_dir, _get_env_bool, + _get_telemetry, _get_user_email, _get_user_id, _is_ci, _read_global_config, _Telemetry, + _telemetry_lock, _write_global_config, record_wrapper_used, ) @@ -268,6 +271,23 @@ def test_record_initializes_client(self, tmp_path: Path) -> None: assert telemetry._client is mock_client mock_client.capture.assert_called_once() + def test_initialization_reads_global_config_once(self) -> None: + """Initialization shares one config read across user identity lookups.""" + config = {"id": "test-user-id", "account": {"email": "test@example.com"}} + + with ( + mock.patch.dict(os.environ, {}, clear=True), + mock.patch("promptfoo.telemetry._read_global_config", return_value=config) as mock_read_config, + mock.patch("promptfoo.telemetry.Posthog") as mock_posthog, + ): + telemetry = _Telemetry() + telemetry._ensure_initialized() + + mock_read_config.assert_called_once_with() + assert telemetry._user_id == "test-user-id" + assert telemetry._email == "test@example.com" + mock_posthog.assert_called_once() + def test_record_enriches_properties(self, tmp_path: Path) -> None: """Test record adds enriched properties.""" config_file = tmp_path / "promptfoo.yaml" @@ -471,3 +491,38 @@ def test_record_wrapper_used_disabled(self, monkeypatch: pytest.MonkeyPatch) -> with mock.patch("promptfoo.telemetry._telemetry", None): # Should not raise or make any calls record_wrapper_used("global") + + def test_get_telemetry_guards_singleton_initialization_with_lock(self) -> None: + """Singleton construction waits on its lock and registers shutdown once.""" + started = threading.Event() + finished = threading.Event() + instance = mock.Mock(spec=_Telemetry) + results: list[_Telemetry] = [] + + def initialize() -> None: + started.set() + results.append(_get_telemetry()) + finished.set() + + with ( + mock.patch("promptfoo.telemetry._telemetry", None), + mock.patch("promptfoo.telemetry._Telemetry", return_value=instance) as mock_telemetry, + mock.patch("promptfoo.telemetry.atexit.register") as mock_register, + ): + _telemetry_lock.acquire() + try: + worker = threading.Thread(target=initialize) + worker.start() + assert started.wait(timeout=1) + assert finished.wait(timeout=0.05) is False + mock_telemetry.assert_not_called() + finally: + _telemetry_lock.release() + + worker.join(timeout=1) + assert worker.is_alive() is False + + assert results == [instance] + assert _get_telemetry() is instance + mock_telemetry.assert_called_once_with() + mock_register.assert_called_once_with(instance.shutdown)