From f995050a0ad26019f821d6103b46cc656c6c5d32 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Thu, 11 Jun 2026 10:50:05 -0700 Subject: [PATCH] fix: EnumTransformer.to_literal accepts a string matching an enum value Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/core/type_engine.py | 7 ++++ tests/flytekit/unit/core/test_type_engine.py | 40 ++++++++++++++------ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9993c98479..1a953a9d07 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1081,6 +1081,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal( self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType ) -> Literal: + # Accept a raw string that matches one of the enum's values. assert_type already + # allows this (e.g. an enum default supplied as a string), so to_literal must too, + # otherwise such a value passes type-checking but fails serialization. + if isinstance(python_val, str): + if python_val not in [item.value for item in python_type]: + raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}") + return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) # type: ignore if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8945ea46dd..11ff02ea31 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,27 +6,26 @@ import sys import tempfile import typing +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass, field from datetime import timedelta from enum import Enum, auto -from typing import List, Optional, Type, Dict +from typing import Dict, List, Optional, Type import mock import msgpack import pytest import typing_extensions -from concurrent.futures import ThreadPoolExecutor from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import errors_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct -from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from mashumaro.config import BaseConfig from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.mixins.orjson import DataClassORJSONMixin from mashumaro.types import Discriminator -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, get_args from flytekit import dynamic, kwtypes, task, workflow from flytekit.core.annotation import FlyteAnnotation @@ -34,16 +33,16 @@ from flytekit.core.data_persistence import flyte_tmp_dir from flytekit.core.hash import HashMethod from flytekit.core.type_engine import ( - IntTransformer, - FloatTransformer, BoolTransformer, - StrTransformer, DataclassTransformer, DictTransformer, EnumTransformer, + FloatTransformer, + IntTransformer, ListTransformer, LiteralsResolver, SimpleTransformer, + StrTransformer, TypeEngine, TypeTransformer, TypeTransformerFailedError, @@ -68,7 +67,7 @@ LiteralOffloadedMetadata, Primitive, Scalar, - Void, Binary, + Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import TensorboardLogs @@ -79,11 +78,11 @@ from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.iterator.iterator import IteratorTransformer -from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON +from flytekit.types.iterator.json_iterator import JSON, JSONIterator, JSONIteratorTransformer from flytekit.types.pickle import FlytePickle from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.types.schema import FlyteSchema -from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET +from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine T = typing.TypeVar("T") @@ -1532,6 +1531,20 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +def test_enum_to_literal_accepts_matching_string(): + # A string matching an enum value is accepted by assert_type, so to_literal must + # accept it too (e.g. an enum default supplied as a string). A non-matching string + # is still rejected. + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(Color) + + lv = TypeEngine.to_literal(ctx, "red", Color, lt) + assert lv.scalar.primitive.string_value == "red" + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, "purple", Color, lt) + + def test_multi_inheritance_enum_type(): tfm = TypeEngine.get_transformer(MultiInheritanceColor) assert isinstance(tfm, EnumTransformer) @@ -4112,8 +4125,9 @@ def test_asyncio_wait_empty_kwargs_regression(): "Set of Tasks/Futures is empty" ValueError. """ import asyncio - from flytekit.models import literals as _literal_models + from flytekit.core.type_engine import TypeEngine + from flytekit.models import literals as _literal_models async def simulate_original_bug(): """ @@ -4153,8 +4167,8 @@ def test_error_message_improvements_literal_map_to_kwargs(): Test that error messages in literal_map_to_kwargs use proper repr formatting for better debugging experience. """ - from flytekit.models import literals as _literal_models from flytekit.core.type_engine import TypeTransformerFailedError + from flytekit.models import literals as _literal_models ctx = FlyteContextManager.current_context() @@ -4192,6 +4206,7 @@ def test_error_message_improvements_union_transformer(): for better debugging when conversion fails. """ from typing import Union + from flytekit.models import literals as _literal_models ctx = FlyteContextManager.current_context() @@ -4229,6 +4244,7 @@ def test_debug_logging_union_transformer(caplog): """ import logging from typing import Union + from flytekit.models import literals as _literal_models # Set logging level to capture debug messages