diff --git a/src/msgspec_ext/settings.py b/src/msgspec_ext/settings.py index 0b2435a..84612d1 100644 --- a/src/msgspec_ext/settings.py +++ b/src/msgspec_ext/settings.py @@ -1,6 +1,7 @@ """Optimized settings management using msgspec.Struct and bulk JSON decoding.""" import os +import types from typing import Annotated, Any, ClassVar, Union, get_args, get_origin import msgspec @@ -133,7 +134,8 @@ class SettingsConfigDict(msgspec.Struct): env_file_encoding: str = "utf-8" case_sensitive: bool = False env_prefix: str = "" - env_nested_delimiter: str = "__" + env_nested_delimiter: str | None = None + env_nested_max_depth: int = 0 class BaseSettings: @@ -226,14 +228,17 @@ def _create_struct_class(cls): if field_name == "model_config": continue + # Resolve BaseSettings subclass types to struct equivalents + resolved_type = cls._resolve_field_type(field_type) + # Get default value from class attribute if exists if hasattr(cls, field_name): default_value = getattr(cls, field_name) # Field with default: (name, type, default) - goes to optional - optional_fields.append((field_name, field_type, default_value)) + optional_fields.append((field_name, resolved_type, default_value)) else: # Required field: (name, type) - goes to required - required_fields.append((field_name, field_type)) + required_fields.append((field_name, resolved_type)) # IMPORTANT: Required fields must come before optional fields # This avoids "Required field cannot follow optional fields" error @@ -251,6 +256,58 @@ def _create_struct_class(cls): return struct_cls + @classmethod + def _resolve_field_type(cls, field_type): + """Convert BaseSettings subclass types to their struct equivalents. + + Recursively walks the type tree to handle: + - Direct: DatabaseSettings → DatabaseStruct + - Union/Optional: DatabaseSettings | None → DatabaseStruct | None + - Generic: list[DatabaseSettings] → list[DatabaseStruct] + - Annotated: Annotated[DatabaseSettings, Meta] → Annotated[DatabaseStruct, Meta] + """ + # Direct BaseSettings subclass + if ( + isinstance(field_type, type) + and field_type is not BaseSettings + and issubclass(field_type, BaseSettings) + ): + return field_type._get_or_create_struct_class() + + origin = get_origin(field_type) + if origin is None: + return field_type + + args = get_args(field_type) + if not args: + return field_type + + # Handle Union / Optional (typing.Union and Python 3.10+ X | Y) + if origin is Union or origin is types.UnionType: + resolved = tuple(cls._resolve_field_type(a) for a in args) + if resolved != args: + result = resolved[0] + for t in resolved[1:]: + result = result | t + return result + return field_type + + # Handle Annotated[BaseType, *metadata] + if origin is Annotated: + resolved_base = cls._resolve_field_type(args[0]) + if resolved_base is not args[0]: + return Annotated[(resolved_base, *args[1:])] + return field_type + + # Handle generic types: list[X], dict[K, V], set[X], tuple[X, ...], etc. + resolved = tuple(cls._resolve_field_type(a) for a in args) + if resolved != args: + return ( + origin[resolved[0]] if len(resolved) == 1 else origin[tuple(resolved)] + ) + + return field_type + @classmethod def _inject_helper_methods(cls, struct_cls): """Inject helper methods into the dynamically created Struct.""" @@ -365,7 +422,14 @@ def _load_env_files(cls): @classmethod def _collect_env_values(cls, struct_cls) -> dict[str, Any]: - """Collect environment variable values for all fields. + """Collect environment variable values for all fields.""" + if cls.model_config.env_nested_delimiter is not None: + return cls._collect_nested_env_values(struct_cls) + return cls._collect_flat_env_values(struct_cls) + + @classmethod + def _collect_flat_env_values(cls, struct_cls) -> dict[str, Any]: + """Collect flat environment variable values for all fields. Returns dict with field_name -> converted_value. Highly optimized with cached field->env name mapping. @@ -399,6 +463,100 @@ def _collect_env_values(cls, struct_cls) -> dict[str, Any]: return env_dict + @classmethod + def _collect_nested_env_values(cls, struct_cls) -> dict[str, Any]: + """Collect env values and unfold nested keys by delimiter.""" + delimiter = cls.model_config.env_nested_delimiter + prefix = cls.model_config.env_prefix + case_sensitive = cls.model_config.case_sensitive + max_depth = cls.model_config.env_nested_max_depth + + # Precompute prefix and delimiter for case-insensitive matching + if not case_sensitive: + prefix = prefix.upper() + delimiter = delimiter.upper() + prefix_len = len(prefix) + + result = {} + struct_fields = set(struct_cls.__struct_fields__) + + for env_key, env_value in os.environ.items(): + # Normalize key for matching + key = env_key if case_sensitive else env_key.upper() + + # Check and strip prefix + if prefix: + if not key.startswith(prefix): + continue + key = key[prefix_len:] + + # Split by delimiter (respect max_depth) + if max_depth > 0: + parts = key.split(delimiter, maxsplit=max_depth) + else: + parts = key.split(delimiter) + + # Normalize parts to field names (lowercase for case-insensitive) + if not case_sensitive: + parts = [p.lower() for p in parts] + + # Skip if root field doesn't exist in struct + if parts[0] not in struct_fields: + continue + + # Resolve leaf type and preprocess value + leaf_type = cls._resolve_leaf_type(struct_cls, parts) + if leaf_type is not None: + converted = cls._preprocess_env_value(env_value, leaf_type) + else: + converted = env_value + + # Set value in nested dict + cls._set_nested_value(result, parts, converted) + + return result + + @staticmethod + def _set_nested_value(target: dict, parts: list[str], value: Any) -> None: + """Set a value in a nested dict by key path.""" + for part in parts[:-1]: + if part not in target or not isinstance(target[part], dict): + target[part] = {} + target = target[part] + target[parts[-1]] = value + + @classmethod + def _resolve_leaf_type(cls, struct_cls, parts: list[str]) -> type | None: + """Walk struct annotations to find the leaf field type.""" + current = struct_cls + for i, part in enumerate(parts): + annotations = current.__annotations__ + if part not in annotations: + return None + field_type = annotations[part] + if i == len(parts) - 1: + return field_type + # Navigate into nested struct + inner = cls._unwrap_struct_type(field_type) + if inner is not None and hasattr(inner, "__struct_fields__"): + current = inner + else: + return None + return None + + @classmethod + def _unwrap_struct_type(cls, field_type) -> type | None: + """Extract the core struct type, unwrapping Optional if needed.""" + if isinstance(field_type, type) and hasattr(field_type, "__struct_fields__"): + return field_type + origin = get_origin(field_type) + if origin is Union or origin is types.UnionType: + args = get_args(field_type) + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1 and hasattr(non_none[0], "__struct_fields__"): + return non_none[0] + return None + @classmethod def _get_env_name(cls, field_name: str) -> str: """Convert Python field name to environment variable name. diff --git a/tests/test_settings.py b/tests/test_settings.py index 06dee96..d847735 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -392,3 +392,261 @@ class AppSettings(BaseSettings): assert settings.port == 7000 # explicit overrides env finally: os.environ.pop("PORT", None) + + +def test_nested_env_vars_basic(): + """Test basic nested struct from environment variables.""" + os.environ["DATABASE__HOST"] = "localhost" + os.environ["DATABASE__PORT"] = "5432" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "127.0.0.1" + port: int = 3306 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter="__") + + name: str = "app" + database: DatabaseSettings + + settings = AppSettings() + assert settings.name == "app" + assert settings.database.host == "localhost" + assert settings.database.port == 5432 + finally: + os.environ.pop("DATABASE__HOST", None) + os.environ.pop("DATABASE__PORT", None) + + +def test_nested_env_vars_with_prefix(): + """Test nested env vars with env_prefix.""" + os.environ["APP_DATABASE__HOST"] = "db.example.com" + os.environ["APP_DATABASE__PORT"] = "5433" + os.environ["APP_NAME"] = "myapp" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="APP_", env_nested_delimiter="__" + ) + + name: str = "default" + database: DatabaseSettings + + settings = AppSettings() + assert settings.name == "myapp" + assert settings.database.host == "db.example.com" + assert settings.database.port == 5433 + finally: + os.environ.pop("APP_DATABASE__HOST", None) + os.environ.pop("APP_DATABASE__PORT", None) + os.environ.pop("APP_NAME", None) + + +def test_nested_env_vars_deep(): + """Test 3 levels of nesting.""" + os.environ["DATABASE__POSTGRES__HOST"] = "pg.local" + os.environ["DATABASE__POSTGRES__PORT"] = "5432" + + try: + + class PostgresSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class DatabaseSettings(BaseSettings): + postgres: PostgresSettings + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter="__") + + database: DatabaseSettings + + settings = AppSettings() + assert settings.database.postgres.host == "pg.local" + assert settings.database.postgres.port == 5432 + finally: + os.environ.pop("DATABASE__POSTGRES__HOST", None) + os.environ.pop("DATABASE__POSTGRES__PORT", None) + + +def test_nested_env_vars_max_depth(): + """Test env_nested_max_depth limits splitting.""" + os.environ["DATABASE__POSTGRES__HOST"] = "pg.local" + + try: + + class DatabaseSettings(BaseSettings): + # With max_depth=1, "POSTGRES__HOST" stays as single key + postgres__host: str = "default" + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict( + env_nested_delimiter="__", env_nested_max_depth=1 + ) + + database: DatabaseSettings + + settings = AppSettings() + # max_depth=1: DATABASE__POSTGRES__HOST splits into ["database", "postgres__host"] + assert settings.database.postgres__host == "pg.local" + finally: + os.environ.pop("DATABASE__POSTGRES__HOST", None) + + +def test_nested_delimiter_none_disables(): + """Test that env_nested_delimiter=None uses flat lookup only.""" + os.environ["DATABASE__HOST"] = "should-be-ignored" + os.environ["NAME"] = "flat-app" + + try: + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter=None) + + name: str = "default" + + settings = AppSettings() + assert settings.name == "flat-app" + # DATABASE__HOST is not a field name, so it's ignored + finally: + os.environ.pop("DATABASE__HOST", None) + os.environ.pop("NAME", None) + + +def test_nested_env_vars_case_insensitive(): + """Test nested env vars with case_sensitive=False (default).""" + os.environ["DATABASE__HOST"] = "ci-host" + os.environ["DATABASE__PORT"] = "3307" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict( + case_sensitive=False, env_nested_delimiter="__" + ) + + database: DatabaseSettings + + settings = AppSettings() + assert settings.database.host == "ci-host" + assert settings.database.port == 3307 + finally: + os.environ.pop("DATABASE__HOST", None) + os.environ.pop("DATABASE__PORT", None) + + +def test_nested_env_vars_from_dotenv(): + """Test loading nested env vars from .env file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f: + f.write("DATABASE__HOST=dotenv-host\n") + f.write("DATABASE__PORT=6543\n") + f.write("APP_NAME=dotenv-app\n") + env_file_path = f.name + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict( + env_file=env_file_path, env_nested_delimiter="__" + ) + + app_name: str = "default" + database: DatabaseSettings + + settings = AppSettings() + assert settings.app_name == "dotenv-app" + assert settings.database.host == "dotenv-host" + assert settings.database.port == 6543 + finally: + os.environ.pop("DATABASE__HOST", None) + os.environ.pop("DATABASE__PORT", None) + os.environ.pop("APP_NAME", None) + Path(env_file_path).unlink(missing_ok=True) + + +def test_nested_with_optional_struct(): + """Test Optional nested struct field.""" + os.environ["DATABASE__HOST"] = "opt-host" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter="__") + + name: str = "app" + database: DatabaseSettings | None = None + + settings = AppSettings() + assert settings.database.host == "opt-host" + assert settings.database.port == 5432 + finally: + os.environ.pop("DATABASE__HOST", None) + + +def test_nested_with_all_defaults(): + """Test nested struct where all sub-fields have defaults.""" + # No DATABASE__ env vars set — nested struct should still work via defaults + # when passed as an empty dict + os.environ["DATABASE__HOST"] = "explicit" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter="__") + + name: str = "app" + database: DatabaseSettings + + settings = AppSettings() + assert settings.database.host == "explicit" + assert settings.database.port == 5432 + finally: + os.environ.pop("DATABASE__HOST", None) + + +def test_nested_custom_delimiter(): + """Test using a custom delimiter other than __.""" + os.environ["DATABASE.HOST"] = "dot-host" + os.environ["DATABASE.PORT"] = "9999" + + try: + + class DatabaseSettings(BaseSettings): + host: str = "localhost" + port: int = 5432 + + class AppSettings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter=".") + + database: DatabaseSettings + + settings = AppSettings() + assert settings.database.host == "dot-host" + assert settings.database.port == 9999 + finally: + os.environ.pop("DATABASE.HOST", None) + os.environ.pop("DATABASE.PORT", None)