Comments (4)
Hey @polvalente, could you please elaborate on your use case a little bit?
- How do you get your StableHLO module in the first place? (By training a JAX model, I guess?)
- What does the TF graph look like? Does it fully digest the StableHLO module, or is XlaCallModule a high-level op that still delegates to the StableHLO graph?
- Could you please share the original error message?
- Could you please provide a minimal reproduction?
from coremltools.
@YifanShenSZ Thanks for the response!
I'll preface this by saying that my end goal is to have some way to compile https://github.com/elixir-nx/nx into CoreML for running on iOS devices. I tried StableHLO because that seemed to be a good, ready pathway.
1. It's close to how you would get a JAX model, but I'm actually using https://github.com/elixir-nx/nx/tree/main/exla, which uses OpenXLA to compile Elixir (through Nx) code into a StableHLO module.
2. XlaCallModule is a high-level op that contains/delegates to the StableHLO graph. Unfortunately, I wasn't able to find a way to fully digest a StableHLO module into a TF graph.
3 and 4: I'll get back to you shortly with a reproduction script and sample output.
I see. Since XlaCallModule is a delegation, unfortunately we won't be able to convert it as is.
However, I do think that if you can let the TF graph digest the StableHLO module, then we may have a chance to convert it. One potential way to do so is to mimic JAX2TF, although you may need to dig into how MLIR is called under the hood.
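A minimal sketch of that direction, for whatever it's worth: older jax2tf versions accept native_serialization=False, which lowers a JAX function to ordinary TF ops instead of a single XlaCallModule (note that newer JAX releases may have deprecated or removed this non-native path, and the toy function here is made up):

```python
import tensorflow as tf
from jax.experimental import jax2tf

# A toy JAX-traceable function standing in for the real computation (assumption).
def f(a, b):
    return a * b + (a - b)

# native_serialization=False asks jax2tf to emit plain TF ops,
# so the resulting graph contains no XlaCallModule delegation.
tf_fn = tf.function(jax2tf.convert(f, native_serialization=False),
                    autograph=False)

out = tf_fn(tf.constant([2, 3], tf.int32), tf.constant([5, 7], tf.int32))
```

If this produces a graph of regular TF ops, that graph is what coremltools would then need to see.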
@YifanShenSZ Unfortunately, I think JAX2TF does exactly what I'm doing. Thanks for the suggestion, though!
Here's the reproduction script and sample output anyway:
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla as tfxla
import coremltools as ct

stablehlo_module = """
module {
  func.func public @main(%arg0: tensor<2xi64>, %arg1: tensor<2xi64>) -> tensor<2xi64> {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %2 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %3 = stablehlo.multiply %1, %2 : tensor<2xi64>
    %4 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %5 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %6 = stablehlo.subtract %4, %5 : tensor<2xi64>
    %7 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi64>) -> tensor<2xi64>
    %9 = stablehlo.add %7, %8 : tensor<2xi64>
    stablehlo.return %9 : tensor<2xi64>
  }
}
"""

# Define the computation as a tf.function wrapping XlaCallModule
@tf.function(input_signature=[
    tf.TensorSpec(shape=[2], dtype=tf.int64, name='arg0'),
    tf.TensorSpec(shape=[2], dtype=tf.int64, name='arg1'),
])
def model_fn(arg0, arg1):
    args = [arg0, arg1]
    res = tfxla.call_module(
        args,
        module=stablehlo_module,
        Tout=[tf.int64],
        Sout=[(2,)],
        function_list=[],
        platforms=[],
    )
    return res

# Convert the frozen TF graph (not a SavedModel) to Core ML
coreml_model = ct.convert(
    model_fn.get_concrete_function().graph.as_graph_def(),
    source='tensorflow',
)
I am able to run the model_fn function as if it were a standard TF model, so there's that at least :)
Feel free to close this if you do feel there's no path forward here!
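For reference, the StableHLO module above computes elementwise arg0 * arg1 + (arg0 - arg1) (the multiply, subtract, and add ops in %3, %6, and %9), so a plain NumPy check of what model_fn should return might look like this (a sketch; the input values are made up):

```python
import numpy as np

# Reference computation mirroring the StableHLO module:
# %3 = multiply, %6 = subtract, %9 = add of the two intermediates
arg0 = np.array([2, 3], dtype=np.int64)
arg1 = np.array([5, 7], dtype=np.int64)

expected = arg0 * arg1 + (arg0 - arg1)
print(expected)  # -> [ 7 17]
```

Comparing this against model_fn's output is a quick way to confirm the XlaCallModule wrapper executes the StableHLO graph correctly before attempting conversion.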