"""Minimal demo: lower a tiny PyTorch model to Torch-MLIR through TorchDynamo.

Run as a script (``python3 <this file>``). Importing the module only defines
the model, the data helper and the backend; torch-mlir is imported lazily so
it is needed only when the demo actually runs.
"""
import torch
from torch import nn


class HelloNNWorld(nn.Module):
    """Single linear layer over a flattened H x W input, followed by ReLU.

    Args:
        H: Input height.
        W: Input width.
        C: Number of output features (classes).
    """

    def __init__(self, H=28, W=28, C=10):
        super().__init__()
        self.linear = nn.Linear(H * W, C)

    def forward(self, x):
        """Map ``x`` to ``(batch, C)`` non-negative scores.

        Every dimension after the first is flattened, so their product must
        equal ``H * W`` (e.g. a ``(b, H, W)`` input).
        """
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        # ReLU on the output layer is kept from the original demo; it clamps
        # the scores to be non-negative.
        return nn.functional.relu(x)


def generate_data(b, H=28, W=28, C=10):
    """Return a random ``(inputs, labels)`` batch matching ``HelloNNWorld``.

    Args:
        b: Batch size.
        H, W: Spatial size of each input; defaults match ``HelloNNWorld``.
        C: Number of classes; labels are drawn uniformly from ``[0, C)``.

    Returns:
        ``inputs`` of shape ``(b, H, W)`` and dtype float32, and ``labels`` of
        shape ``(b,)`` as produced by ``torch.randint`` (int64 by default).
    """
    inputs = torch.randn(b, H, W).to(torch.float32)
    labels = torch.randint(C, (b,))
    return inputs, labels


def torch_mlir_backend(gm, example_inputs):
    """TorchDynamo backend that prints the FX graph and its Torch-MLIR form.

    Lowering here is for inspection only: the original ``gm`` is returned,
    so execution still runs the FX graph rather than MLIR-compiled code.
    """
    # Imported lazily so this module can be imported without torch-mlir.
    import torch_mlir

    print("FX Graph:")
    print(gm)
    print("Torch MLIR backend:")
    # NOTE(review): ``torch_mlir.compile`` is deprecated/removed in newer
    # torch-mlir releases -- confirm against the installed version.
    mlir_module = torch_mlir.compile(gm, example_inputs, output_type="torch")
    print(mlir_module)
    return gm


def main():
    """Wrap the model with the Torch-MLIR backend and run one batch through it."""
    import torch._dynamo as dynamo
    from torch_mlir.dynamo import make_simple_dynamo_backend

    model = HelloNNWorld()
    backend = make_simple_dynamo_backend(torch_mlir_backend)
    mlir_model = dynamo.optimize(backend)(model)
    # TorchDynamo traces on the first call, which is when torch_mlir_backend
    # runs and prints the graphs.
    mlir_model(generate_data(16)[0])


if __name__ == "__main__":
    main()