Skip to content

Add support for _expected_input_spec in MLX Lambda to match JAX implementation#17

Open
lancelotblanchard wants to merge 1 commit into
google:mainfrom
lancelotblanchard:mlx-lambda
Open

Add support for _expected_input_spec in MLX Lambda to match JAX implementation#17
lancelotblanchard wants to merge 1 commit into
google:mainfrom
lancelotblanchard:mlx-lambda

Conversation

@lancelotblanchard

Copy link
Copy Markdown

This PR implements expected_input_spec support in the MLX Lambda layer to match the JAX implementation. Previously, expected_input_spec was accepted in Lambda.Config for compatibility but was completely ignored by the MLX runtime.

The MLX dynamic shape probing (_probe_output) currently supports a input_dtype argument, but it is ignored by get_output_shape, which assumes mx.float32. The input_dtype argument is now overwritten by the config's expected_input_spec.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant