Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
def to_literal(
self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType
) -> Literal:
# Accept a raw string that matches one of the enum's values. assert_type already
# allows this (e.g. an enum default supplied as a string), so to_literal must too,
# otherwise such a value passes type-checking but fails serialization.
if isinstance(python_val, str):
if python_val not in [item.value for item in python_type]:
raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}")
return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) # type: ignore
if type(python_val).__class__ != enum.EnumMeta:
raise TypeTransformerFailedError("Expected an enum")
if type(python_val.value) != str:
Expand Down
40 changes: 28 additions & 12 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,43 @@
import sys
import tempfile
import typing
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import List, Optional, Type, Dict
from typing import Dict, List, Optional, Type

import mock
import msgpack
import pytest
import typing_extensions
from concurrent.futures import ThreadPoolExecutor
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import errors_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from mashumaro.config import BaseConfig
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import Discriminator
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, get_args

from flytekit import dynamic, kwtypes, task, workflow
from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import flyte_tmp_dir
from flytekit.core.hash import HashMethod
from flytekit.core.type_engine import (
IntTransformer,
FloatTransformer,
BoolTransformer,
StrTransformer,
DataclassTransformer,
DictTransformer,
EnumTransformer,
FloatTransformer,
IntTransformer,
ListTransformer,
LiteralsResolver,
SimpleTransformer,
StrTransformer,
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
Expand All @@ -68,7 +67,7 @@
LiteralOffloadedMetadata,
Primitive,
Scalar,
Void, Binary,
Void,
)
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType
from flytekit.types.directory import TensorboardLogs
Expand All @@ -79,11 +78,11 @@
from flytekit.types.file import FileExt, JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop
from flytekit.types.iterator.iterator import IteratorTransformer
from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON
from flytekit.types.iterator.json_iterator import JSON, JSONIterator, JSONIteratorTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

T = typing.TypeVar("T")

Expand Down Expand Up @@ -1532,6 +1531,20 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_enum_to_literal_accepts_matching_string():
# A string matching an enum value is accepted by assert_type, so to_literal must
# accept it too (e.g. an enum default supplied as a string). A non-matching string
# is still rejected.
ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(Color)

lv = TypeEngine.to_literal(ctx, "red", Color, lt)
assert lv.scalar.primitive.string_value == "red"

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, "purple", Color, lt)


def test_multi_inheritance_enum_type():
tfm = TypeEngine.get_transformer(MultiInheritanceColor)
assert isinstance(tfm, EnumTransformer)
Expand Down Expand Up @@ -4112,8 +4125,9 @@ def test_asyncio_wait_empty_kwargs_regression():
"Set of Tasks/Futures is empty" ValueError.
"""
import asyncio
from flytekit.models import literals as _literal_models

from flytekit.core.type_engine import TypeEngine
from flytekit.models import literals as _literal_models

async def simulate_original_bug():
"""
Expand Down Expand Up @@ -4153,8 +4167,8 @@ def test_error_message_improvements_literal_map_to_kwargs():
Test that error messages in literal_map_to_kwargs use proper repr formatting
for better debugging experience.
"""
from flytekit.models import literals as _literal_models
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.models import literals as _literal_models

ctx = FlyteContextManager.current_context()

Expand Down Expand Up @@ -4192,6 +4206,7 @@ def test_error_message_improvements_union_transformer():
for better debugging when conversion fails.
"""
from typing import Union

from flytekit.models import literals as _literal_models

ctx = FlyteContextManager.current_context()
Expand Down Expand Up @@ -4229,6 +4244,7 @@ def test_debug_logging_union_transformer(caplog):
"""
import logging
from typing import Union

from flytekit.models import literals as _literal_models

# Set logging level to capture debug messages
Expand Down
Loading