diff --git a/aikido_zen/__init__.py b/aikido_zen/__init__.py index abf507e0..81c07d60 100644 --- a/aikido_zen/__init__.py +++ b/aikido_zen/__init__.py @@ -24,6 +24,9 @@ from aikido_zen.helpers.aikido_disabled_flag_active import aikido_disabled_flag_active +VALID_MODES = ("daemon", "daemon_only", "daemon_disabled") + + def protect(mode="daemon", token=""): """ Mode can be set to : @@ -32,6 +35,10 @@ def protect(mode="daemon", token=""): - daemon_disabled : This will import sinks/sources but won't start a background process Protect user's application """ + if mode not in VALID_MODES: + raise ValueError( + f"Invalid mode {mode!r}, expected one of {VALID_MODES}. To pass a token, use protect(token=...)" + ) if aikido_disabled_flag_active(): # Do not run any aikido code when the disabled flag is on return diff --git a/aikido_zen/init_test.py b/aikido_zen/init_test.py index 5a99f7e9..f9ec3ae1 100644 --- a/aikido_zen/init_test.py +++ b/aikido_zen/init_test.py @@ -17,3 +17,13 @@ def test_protect_with_django(monkeypatch, caplog): def test_protect_sets_token(): aikido_zen.protect(token="MY_TOKEN_1") assert get_token_from_env().token == "MY_TOKEN_1" + + +def test_protect_rejects_invalid_mode(): + with pytest.raises(ValueError, match=r"Invalid mode .*protect\(token=\.\.\.\)"): + aikido_zen.protect("AIK_RUNTIME_some-token-string") + + +@pytest.mark.parametrize("mode", ["daemon", "daemon_only", "daemon_disabled"]) +def test_protect_accepts_valid_modes(mode): + aikido_zen.protect(mode=mode) diff --git a/aikido_zen/sinks/tests/langchain_test.py b/aikido_zen/sinks/tests/langchain_test.py index e1b7a694..55d1f415 100644 --- a/aikido_zen/sinks/tests/langchain_test.py +++ b/aikido_zen/sinks/tests/langchain_test.py @@ -5,7 +5,7 @@ import aikido_zen from aikido_zen.thread.thread_cache import get_cache -aikido_zen.protect(mode="daemon-disabled") +aikido_zen.protect(mode="daemon_disabled") skip_no_openai_key = pytest.mark.skipif( "OPENAI_API_KEY" not in os.environ,