Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 162 additions & 4 deletions src/msgspec_ext/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Optimized settings management using msgspec.Struct and bulk JSON decoding."""

import os
import types
from typing import Annotated, Any, ClassVar, Union, get_args, get_origin

import msgspec
Expand Down Expand Up @@ -133,7 +134,8 @@ class SettingsConfigDict(msgspec.Struct):
env_file_encoding: str = "utf-8"
case_sensitive: bool = False
env_prefix: str = ""
env_nested_delimiter: str = "__"
env_nested_delimiter: str | None = None
env_nested_max_depth: int = 0


class BaseSettings:
Expand Down Expand Up @@ -226,14 +228,17 @@ def _create_struct_class(cls):
if field_name == "model_config":
continue

# Resolve BaseSettings subclass types to struct equivalents
resolved_type = cls._resolve_field_type(field_type)

# Get default value from class attribute if exists
if hasattr(cls, field_name):
default_value = getattr(cls, field_name)
# Field with default: (name, type, default) - goes to optional
optional_fields.append((field_name, field_type, default_value))
optional_fields.append((field_name, resolved_type, default_value))
else:
# Required field: (name, type) - goes to required
required_fields.append((field_name, field_type))
required_fields.append((field_name, resolved_type))

# IMPORTANT: Required fields must come before optional fields
# This avoids "Required field cannot follow optional fields" error
Expand All @@ -251,6 +256,58 @@ def _create_struct_class(cls):

return struct_cls

@classmethod
def _resolve_field_type(cls, field_type):
"""Convert BaseSettings subclass types to their struct equivalents.

Recursively walks the type tree to handle:
- Direct: DatabaseSettings → DatabaseStruct
- Union/Optional: DatabaseSettings | None → DatabaseStruct | None
- Generic: list[DatabaseSettings] → list[DatabaseStruct]
- Annotated: Annotated[DatabaseSettings, Meta] → Annotated[DatabaseStruct, Meta]
"""
# Direct BaseSettings subclass
if (
isinstance(field_type, type)
and field_type is not BaseSettings
and issubclass(field_type, BaseSettings)
):
return field_type._get_or_create_struct_class()

origin = get_origin(field_type)
if origin is None:
return field_type

args = get_args(field_type)
if not args:
return field_type

# Handle Union / Optional (typing.Union and Python 3.10+ X | Y)
if origin is Union or origin is types.UnionType:
resolved = tuple(cls._resolve_field_type(a) for a in args)
if resolved != args:
result = resolved[0]
for t in resolved[1:]:
result = result | t
return result
return field_type

# Handle Annotated[BaseType, *metadata]
if origin is Annotated:
resolved_base = cls._resolve_field_type(args[0])
if resolved_base is not args[0]:
return Annotated[(resolved_base, *args[1:])]
return field_type

# Handle generic types: list[X], dict[K, V], set[X], tuple[X, ...], etc.
resolved = tuple(cls._resolve_field_type(a) for a in args)
if resolved != args:
return (
origin[resolved[0]] if len(resolved) == 1 else origin[tuple(resolved)]
)

return field_type

@classmethod
def _inject_helper_methods(cls, struct_cls):
"""Inject helper methods into the dynamically created Struct."""
Expand Down Expand Up @@ -365,7 +422,14 @@ def _load_env_files(cls):

@classmethod
def _collect_env_values(cls, struct_cls) -> dict[str, Any]:
"""Collect environment variable values for all fields.
"""Collect environment variable values for all fields."""
if cls.model_config.env_nested_delimiter is not None:
return cls._collect_nested_env_values(struct_cls)
return cls._collect_flat_env_values(struct_cls)

@classmethod
def _collect_flat_env_values(cls, struct_cls) -> dict[str, Any]:
"""Collect flat environment variable values for all fields.

Returns dict with field_name -> converted_value.
Highly optimized with cached field->env name mapping.
Expand Down Expand Up @@ -399,6 +463,100 @@ def _collect_env_values(cls, struct_cls) -> dict[str, Any]:

return env_dict

@classmethod
def _collect_nested_env_values(cls, struct_cls) -> dict[str, Any]:
"""Collect env values and unfold nested keys by delimiter."""
delimiter = cls.model_config.env_nested_delimiter
prefix = cls.model_config.env_prefix
case_sensitive = cls.model_config.case_sensitive
max_depth = cls.model_config.env_nested_max_depth

# Precompute prefix and delimiter for case-insensitive matching
if not case_sensitive:
prefix = prefix.upper()
delimiter = delimiter.upper()
prefix_len = len(prefix)

result = {}
struct_fields = set(struct_cls.__struct_fields__)

for env_key, env_value in os.environ.items():
# Normalize key for matching
key = env_key if case_sensitive else env_key.upper()

# Check and strip prefix
if prefix:
if not key.startswith(prefix):
continue
key = key[prefix_len:]

# Split by delimiter (respect max_depth)
if max_depth > 0:
parts = key.split(delimiter, maxsplit=max_depth)
else:
parts = key.split(delimiter)

# Normalize parts to field names (lowercase for case-insensitive)
if not case_sensitive:
parts = [p.lower() for p in parts]

# Skip if root field doesn't exist in struct
if parts[0] not in struct_fields:
continue

# Resolve leaf type and preprocess value
leaf_type = cls._resolve_leaf_type(struct_cls, parts)
if leaf_type is not None:
converted = cls._preprocess_env_value(env_value, leaf_type)
else:
converted = env_value

# Set value in nested dict
cls._set_nested_value(result, parts, converted)

return result

@staticmethod
def _set_nested_value(target: dict, parts: list[str], value: Any) -> None:
"""Set a value in a nested dict by key path."""
for part in parts[:-1]:
if part not in target or not isinstance(target[part], dict):
target[part] = {}
target = target[part]
target[parts[-1]] = value

@classmethod
def _resolve_leaf_type(cls, struct_cls, parts: list[str]) -> type | None:
"""Walk struct annotations to find the leaf field type."""
current = struct_cls
for i, part in enumerate(parts):
annotations = current.__annotations__
if part not in annotations:
return None
field_type = annotations[part]
if i == len(parts) - 1:
return field_type
# Navigate into nested struct
inner = cls._unwrap_struct_type(field_type)
if inner is not None and hasattr(inner, "__struct_fields__"):
current = inner
else:
return None
return None

@classmethod
def _unwrap_struct_type(cls, field_type) -> type | None:
"""Extract the core struct type, unwrapping Optional if needed."""
if isinstance(field_type, type) and hasattr(field_type, "__struct_fields__"):
return field_type
origin = get_origin(field_type)
if origin is Union or origin is types.UnionType:
args = get_args(field_type)
non_none = [a for a in args if a is not type(None)]
if len(non_none) == 1 and hasattr(non_none[0], "__struct_fields__"):
return non_none[0]
return None

@classmethod
def _get_env_name(cls, field_name: str) -> str:
"""Convert Python field name to environment variable name.
Expand Down
Loading
Loading