Streaming is available in most browsers and in the Developer app.
-
Get started with MLX for Apple silicon
MLX is a flexible and efficient array framework for numerical computing and machine learning on Apple silicon. This session covers fundamental MLX features such as unified memory, lazy evaluation, and function transformations, and introduces more advanced techniques for building and accelerating machine learning models across Apple platforms using the Swift and Python APIs.
Chapters
- 0:00 - Introduction
- 1:15 - MLX overview
- 4:21 - Key features
- 10:15 - Accelerating MLX
- 17:30 - MLX Swift
Resources
- MLX
- MLX LM - Python API
- MLX Examples
- MLX Explore - Python API
- MLX Framework
- MLX Llama Inference
- MLX Swift
- MLX Swift Examples
Related Videos
WWDC25
-
Hi, I’m Awni. Today I’m excited to introduce you to MLX. MLX is an open source array framework purpose-built for Apple Silicon. It’s a flexible tool that can be used for basic numerical computations all the way to running the largest scale frontier machine learning models on Apple devices. If you want to generate text with large language models, generate images, audio, or even video with the latest models, MLX is for you. You can also use it to train, fine-tune, or otherwise customize your machine learning models directly on your Mac. I’ll start by telling you a bit more about what MLX is and where it’s useful. I’ll also cover the basics of using MLX in Python, including installation and basic array operations. After that, I’ll tell you about some of the key features that set MLX apart from other frameworks.
Next, I’ll walk through some of the tools MLX has to make your machine learning workloads as fast as possible on Apple Silicon.
In the last section, I’ll give you a brief look at MLX Swift and show you how to get started using it.
Let's dive into an introduction to MLX. MLX was built from the ground up to be fast on Apple Silicon, where it can run on the CPU or accelerated on the GPU. You can use MLX for a variety of applications, from small-scale numerical computing to large-scale machine learning.
It’s designed to be easy to use and flexible, without compromising speed and efficiency.
MLX has a core API which closely follows NumPy. You can often use it as a drop-in replacement to accelerate most numerical computations. MLX also has all the tools you need for machine learning, including automatic differentiation and higher level libraries. These higher level APIs are similar to PyTorch and JAX. If you’re coming from any of those frameworks, MLX will be familiar and even easier to get started with.
You can use MLX for more advanced machine learning directly on device.
For example, MLX is used by LM Studio, a popular application for generating text with large language models directly on your Mac.
You can use the MLX LM package, which is built on top of MLX, to generate text and fine-tune language models up to hundreds of billions of parameters in size. To learn more about that, check out the session, “Explore large language models on Apple silicon with MLX”. MLX has a fully-featured Python API, which is useful for rapid prototyping. It also has an API in Swift, which includes all the high-level packages for building and optimizing neural networks.
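As a rough illustration, generating text with the MLX LM Python API can look something like the following; the model name is just an example of a quantized model from the MLX community on Hugging Face, and the exact keyword arguments may differ between MLX LM versions.

# Hedged sketch of text generation with MLX LM (pip install mlx-lm)
# The model identifier below is illustrative; any MLX-format model from the
# mlx-community organization on Hugging Face should work similarly.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
text = generate(
    model,
    tokenizer,
    prompt="Write a short poem about unified memory.",
    verbose=True,
)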
MLX also has APIs in C++ and C. You can use MLX in any of these languages to run the latest machine learning models on Apple Silicon, including Mac, iPhone, iPad, and Vision Pro.
All of the MLX software is open source under a permissive MIT license. The core software is available on GitHub, along with several examples and packages built using the Python and Swift APIs. MLX also has an active community of model creators on Hugging Face. Many of the latest models are already in the MLX community Hugging Face organization, and new models are uploaded every day.
The easiest way to get started with MLX in Python is to install it from PyPI. You only need to run one line in your terminal: pip3 install mlx.
MLX is easy to use. To start performing computations on arrays, simply open up a Python file and import MLX. Then you can make some arrays and do basic operations on them. For example, here we’re adding two integer arrays.
You can also easily inspect information about an array, such as its shape and data type.
As I said before, the MLX Python API is similar to NumPy. The operations usually have the same names and signatures and behave the same way. If you are coming from NumPy or a similar framework, MLX will be familiar and easy to get started with. Now that you know what MLX is, where it’s useful, and some of the basics, let’s learn about the key features. These include unified memory, lazy evaluation, function transformations, and some higher-level packages for building and optimizing neural networks.
MLX is designed to take advantage of the best of Apple Silicon. This includes a new programming model specific to unified memory.
Most systems commonly used for machine learning have a discrete GPU with separate memory. Apple Silicon, on the other hand, has a unified memory architecture. This means that the CPU and the GPU share the same physical memory.
Because MLX works with unified memory, it is different from what you may be used to in traditional frameworks.
In traditional frameworks, computation follows data. If the array is in CPU memory, the computation happens on the CPU. If the array is in GPU memory, the computation happens on the GPU.
In MLX, arrays are allocated in unified memory. You never need to copy them anywhere to use them on any of the supported devices.
Instead, to run an operation on a device, you specify the device to the operation itself.
For example, here we’re adding a and b on the GPU and multiplying a and b on the CPU.
These operations can even run in parallel, and MLX will automatically manage dependencies when they exist. Another key feature of MLX is lazy computation. To make MLX as efficient as possible, especially for very large computations, MLX has a lazy execution engine.
When an operation like an addition on two arrays is called, no actual computation happens. Instead, a computation graph is built, like the one you see here.
The array C is not yet computed. It only gets computed if you actually need to use it. For example, printing C or converting C out of MLX and into a Python list will force it to be computed.
You can explicitly force the graph to be evaluated using mx.eval.
Lazy computation has several benefits. By decoupling the building and execution of the computation graph, MLX can do transformations and optimizations on the graph before computing the results. Also, with lazy computation, you only ever pay for what you use. Function transformations are another key feature of MLX. They elevate it from a fast array framework to a much more powerful tool for training and improving the performance of machine learning models.
Function transformations are transformations which take functions as input and return new functions as output.
MLX has several function transformations. They typically fall into one of two categories: transformations for automatic differentiation and transformations to optimize the compute graph. For example, you can use a function transformation to automatically compute the gradient of any function in MLX.
Suppose you have a basic function which computes the sine of its input. To take the gradient of this function, you can use the mx.grad function transformation.
The output of mx.grad is a new function, which, when you call it on an array, gives you the derivative. Function transformations are arbitrarily composable. You can take the second derivative by simply using mx.grad on the output of mx.grad.
The result is another function, which, when called on an array, gives you the second derivative of the sine.
MLX also has two higher-level packages to make building and training neural networks easy.
The first is mlx.nn, a modular library used to build neural networks.
The second is mlx.optimizers, a library of common optimization algorithms. The two packages can be used standalone, but also seamlessly integrate with one another.
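As a rough sketch of how the two packages fit together, a single training step could look something like the following; the tiny linear model, mean-squared-error loss, and random data here are placeholders, not code from the session.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Placeholder model and data, just to illustrate the mlx.nn / mlx.optimizers interplay
model = nn.Linear(4, 1)
optimizer = optim.SGD(learning_rate=0.01)
x = mx.random.normal((16, 4))
y = mx.zeros((16, 1))

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y, reduction="mean")

# nn.value_and_grad returns a function that computes the loss and the
# gradients with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)

# The optimizer updates the parameters in place; mx.eval forces the lazy
# computation so the update actually runs
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)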
The mlx.nn package has all the functionality you need to build neural networks. nn.Module is the primary base class that all layers and containers inherit from. It exposes helpful methods for accessing, loading, and saving parameters, and more.
The nn library also comes with a set of standard off-the-shelf layers, like nn.Linear, but you can also build your own layers by inheriting from nn.Module.
Commonly used loss functions and initialization methods are also included in the nn.losses and nn.init subpackages.
Let’s take a look at how you can build a simple multilayer neural network with mlx.nn.
The first step is to make a custom class that inherits from nn.Module. In this case, we’ll use a simple one hidden layer neural network. We create the linear layers inside the initialization method of the module. Then we implement the call function, which computes the output of the module on the given input. The call function calls the first linear layer, applies the relu activation function, and then calls the second linear layer. While MLX is designed and optimized for Apple Silicon’s unified memory architecture, to make it easy to get started with, the higher level neural network package is also designed to be similar to commonly used machine learning frameworks, like PyTorch.
Let’s compare the MLX model implementation to the same implementation in PyTorch. They're almost identical, with only two small differences in the function which computes the output.
If you've built models in PyTorch before, then switching to MLX should be very straightforward. Now that you’ve seen most of the core features of MLX, let’s talk about how to use it to make your machine learning workloads as fast and efficient as possible. I'll start by showing you how to compile functions to speed them up. Then I’ll tell you about the mx.fast sub-package, which has off-the-shelf fast implementations of common machine learning operations, and an API for adding your own custom Metal Kernels.
After that, I’ll show you how to use quantization to reduce memory use and speed up model execution. Lastly, I'll show you how to use MLX to distribute a computation across many machines.
Almost every realistic computation in MLX consists of functions which perform several operations on arrays. A simple way to make functions like that faster is with the mx.compile function transformation.
Suppose you have a function which does a few element-wise operations, like the GELU activation function shown here.
The computation graph for this function has several nodes in it. Each of these nodes corresponds to a separate GPU kernel launch under the hood.
Compiling the graph fuses all of these separate kernels into a single kernel. This can save memory bandwidth and graph execution overhead and result in a much more efficient computation.
Using mx.compile is as easy as decorating the function with the mx.compile function transformation.
Compilation often works well, but more complex operations, including some common operations in machine learning, will likely run faster using the mx.fast sub-package. For example, many of the core building blocks of a transformer model use operations in mx.fast. These include the positional encodings, normalization layers, and the scaled dot-product attention. The operations in mx.fast are more specialized, but highly tuned to be as fast as possible for both training and inference. They are also highly configurable, so they can accelerate many variations of a given computation. For example, the scaled dot-product attention operation can take an optional mask as input. The mask can be an additive mask, a Boolean mask, or a string indicating the mask type. Let’s take a closer look at RMS norm, which is one of the operations in mx.fast.
RMS norm is used in nearly every modern transformer-based large language model.
A basic implementation using MLX operations results in a large computation graph.
Instead, we can replace the entire implementation with a single call to mx.fast.rms_norm. The code is simpler, the computation graph only has a single node, and the computation itself will be much faster.
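The fused scaled dot-product attention mentioned earlier can be called in a similarly one-line way. As a rough sketch (the tensor shapes, scale, and string mask here are assumptions about a typical call, not code from the session):

import mlx.core as mx

# Batch, heads, sequence length, and head dimension are arbitrary small values
B, n_heads, L, D = 1, 4, 8, 64
q = mx.random.normal((B, n_heads, L, D))
k = mx.random.normal((B, n_heads, L, D))
v = mx.random.normal((B, n_heads, L, D))

out = mx.fast.scaled_dot_product_attention(
    q, k, v,
    scale=1.0 / D**0.5,  # the usual 1/sqrt(head_dim) scaling
    mask="causal",       # a string mask type, as described above
)
print(out.shape)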
MLX has an API for adding custom Metal kernels for cases where your function could benefit from a more customized implementation, and it’s not already in mx.fast.
You write the custom kernel and MLX handles all the rest, including just-in-time compilation and execution. These kernels are written in Metal, which is a language and API developed by Apple to run functions on Apple GPUs.
You build the kernel by passing in a source string of Metal code, as well as some information about the inputs and outputs.
You call the kernel by specifying the grid size and the shapes and types of the output. MLX treats the kernel call the same way as any other operation. It creates a node in the computation graph and is lazily evaluated.
Another tool in our toolkit for making your machine learning workloads faster is quantization.
Large models need lots of memory and lots of memory bandwidth to be fast. In many cases, the precision you need for training is much higher than what you actually need for inference to get the same or almost as good quality. Reducing the precision lets you fit larger models in memory and run them faster.
If your model is in 32-bit floating point precision, you can use bfloat16 or float16 as a first step to reduce the memory requirements by half. When 16-bits is too many, MLX has built-in routines for quantizing arrays to be even smaller and doing operations with them.
For example, you can quantize using 4-bits per element to shrink memory requirements even further.
To quantize a matrix, use mx.quantize. You specify the number of bits to use per element and the group size. The group size determines the number of elements in the quantized matrix that get a shared scale and bias value. The smaller the bits and the larger the group size, the smaller and faster the result. MLX has several options for bits and group sizes to give you as much flexibility as possible.
You can multiply any unquantized input vector or matrix by your quantized matrix using mx.quantized_matmul.
Use mx.dequantize to recover an approximation to the original input.
mlx.nn also has a handy utility to quantize a module in a single command. Let’s say you have a model that is a stack of an embedding layer followed by several linear layers.
You can quantize the model with nn.quantize. The quantize command also takes an optional callback to let you have more fine-grained control over which layers to quantize and what precision to use for a given layer.
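As a hedged sketch of that callback (the parameter name class_predicate and its (path, module) signature reflect my reading of the mlx.nn API, so check the documentation before relying on them):

import mlx.nn as nn

def quantize_predicate(path, module):
    # Only quantize Linear layers; leave everything else (e.g. the embedding) alone.
    # Returning a dict instead of True could override bits/group_size per layer.
    return isinstance(module, nn.Linear)

nn.quantize(
    model,                 # assumes `model` is the embedding + linear stack described above
    bits=4,
    group_size=32,
    class_predicate=quantize_predicate,
)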
When generating text with large language models, quantization dramatically reduces the memory use and increases the tokens generated per second.
In some cases, one machine is simply not enough. For example, you might want to generate text with a large model that doesn’t fit in the memory of a single machine.
Or you might be fine-tuning or evaluating a model on a large dataset, both of which are easy to parallelize and can be much faster using multiple machines.
MLX comes out of the box with the ability to run arbitrary computations on multiple machines. The machines can be connected over Ethernet or Thunderbolt.
You use the mx.distributed sub-package to distribute computations across multiple machines.
mx.distributed is mostly a set of communication operations.
For example, all_sum adds the input arrays across all machines. The output of all_sum is the summed-up input, which is the same for each machine.
Let’s take a closer look at how to sum an array over multiple machines. Initialize the distributed backend using mx.distributed.init. This step is optional, but it’s what you call if you need access to the communication group.
The communication group has useful information attached to it, like the total number of processes and the current process index.
Then make an array with a single value on each process and call mx.distributed.all_sum to sum the arrays across all processes.
MLX has a handy launcher for running an MLX program on multiple machines. To run a program on 4 machines, use mlx.launch with the 4 host IP addresses. Everything I’ve gone over so far has been using MLX in Python.
In many cases, you may prefer the ease and flexibility of Python. In other cases, you may prefer Swift. For that reason, MLX has a fully featured API in Swift.
It is built on top of Metal and works great across macOS, iOS, iPadOS, visionOS, and more.
Getting started with MLX in Swift is as easy as adding it as a package to your Xcode project. Click on the project and then click on the plus sign in the package dependencies tab.
Then enter the URL for the MLX Swift GitHub repository and click the add package button.
That’s all it takes to start building with MLX Swift.
To make it as easy as possible to transition between Python and Swift, the APIs between the two languages are intentionally similar. Here’s a side-by-side comparison of a Python code snippet we saw earlier with its MLX Swift counterpart.
Making arrays, performing operations on them, and inspecting metadata about the array is almost the same in Swift as in Python. All of the core features we went through using MLX in Python, as well as the optimizations we discussed in the accelerating MLX section, apply equally to MLX Swift. We've gone through a broad overview of many of the key features of MLX. To learn more about the framework, check out the MLX website, which has links to the documentation, examples, and more. Both the Python and Swift APIs have an examples repo, which contains many examples of common machine learning use cases, including language model training and generation, image generation, speech recognition, and many more. These examples are a great way to learn more about MLX, and good starting points to build with MLX in your own project. Thank you for watching, and we’re excited to see what you build with MLX.
-
3:48 - Basics
import mlx.core as mx

# Make an array
a = mx.array([1, 2, 3])

# Make another array
b = mx.array([4, 5, 6])

# Do an operation
c = a + b

# Access information about the array
shape = c.shape
dtype = c.dtype

print(f"Result c: {c}")
print(f"Shape: {shape}")
print(f"Data type: {dtype}")
-
5:31 - Unified memory
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])

c = mx.add(a, b, stream=mx.gpu)
d = mx.multiply(a, b, stream=mx.cpu)

print(f"c computed on the GPU: {c}")
print(f"d computed on the CPU: {d}")
-
6:20 - Lazy computation
import mlx.core as mx

# Make an array
a = mx.array([1, 2, 3])

# Make another array
b = mx.array([4, 5, 6])

# Do an operation
c = a + b

# Evaluates c before printing it
print(c)

# Also evaluates c
c_list = c.tolist()

# Also evaluates c
mx.eval(c)

print(f"Evaluate c by converting to list: {c_list}")
print(f"Evaluate c using print: {c}")
print(f"Evaluate c using mx.eval(): {c}")
-
7:32 - Function transformation
import mlx.core as mx

def sin(x):
    return mx.sin(x)

dfdx = mx.grad(sin)

def sin(x):
    return mx.sin(x)

d2fdx2 = mx.grad(mx.grad(mx.sin))

# Computes the second derivative of sine at 1.0
d2fdx2(mx.array(1.0))
-
9:16 - Neural Networks in MLX
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    """A simple MLP."""

    def __init__(self, dim, h_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def __call__(self, x):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        return x
-
9:57 - MLX and PyTorch
# MLX version
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    """A simple MLP."""

    def __init__(self, dim, h_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def __call__(self, x):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        return x


# PyTorch version
import torch
import torch.nn as nn
import torch.optim as optim

class MLP(nn.Module):
    """A simple MLP."""

    def __init__(self, dim, h_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def forward(self, x):
        x = self.linear1(x)
        x = x.relu()
        x = self.linear2(x)
        return x
-
11:35 - Compiling MLX functions
import mlx.core as mx
import math

def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

@mx.compile
def compiled_gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

x = mx.random.normal(shape=(4,))

out = gelu(x)
compiled_out = compiled_gelu(x)
print(f"gelu: {out}")
print(f"compiled gelu: {compiled_out}")
-
12:32 - MLX Fast package
import mlx.core as mx
import time

def rms_norm(x, weight, eps=1e-5):
    y = x.astype(mx.float32)
    y = y * mx.rsqrt(mx.mean(
        mx.square(y),
        axis=-1,
        keepdims=True,
    ) + eps)
    return (weight * y).astype(x.dtype)

batch_size = 8192
feature_dim = 4096
iterations = 1000

x = mx.random.normal([batch_size, feature_dim])
weight = mx.ones(feature_dim)
bias = mx.zeros(feature_dim)

start_time = time.perf_counter()
for _ in range(iterations):
    y = rms_norm(x, weight, eps=1e-5)
    mx.eval(y)
rms_norm_time = time.perf_counter() - start_time

print(f"rms_norm execution: {rms_norm_time:0.4f} sec")

start_time = time.perf_counter()
for _ in range(iterations):
    mx.eval(mx.fast.rms_norm(x, weight, eps=1e-5))
fast_rms_norm_time = time.perf_counter() - start_time

print(f"mx.fast.rms_norm execution: {fast_rms_norm_time:0.4f} sec")
print(f"mx.fast.rms_norm speedup: {rms_norm_time/fast_rms_norm_time:0.2f}x")
-
13:30 - Custom Metal kernel
import mlx.core as mx

# Build the kernel
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

# Call the kernel on a sample input
x = mx.array([1.0, 2.0, 3.0])

out = kernel(
    inputs=[x],
    grid=(x.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[x.shape],
    output_dtypes=[x.dtype],
)[0]

print(out)
-
14:41 - Quantization
import mlx.core as mx

x = mx.random.normal([1024])
weight = mx.random.normal([1024, 1024])

quantized_weight, scales, biases = mx.quantize(
    weight, bits=4, group_size=32,
)

y = mx.quantized_matmul(
    x,
    quantized_weight,
    scales=scales,
    biases=biases,
    bits=4,
    group_size=32,
)

w_orig = mx.dequantize(
    quantized_weight,
    scales=scales,
    biases=biases,
    bits=4,
    group_size=32,
)
-
15:23 - Quantized models
import mlx.nn as nn

model = nn.Sequential(
    nn.Embedding(100, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 1),
)

print(model)

nn.quantize(
    model,
    bits=4,
    group_size=32,
)

print(model)
-
16:50 - Distributed
import mlx.core as mx

group = mx.distributed.init()

world_size = group.size()
rank = group.rank()

x = mx.array([1.0])

x_sum = mx.distributed.all_sum(x)

print(x_sum)
-
17:20 - Distributed launcher
mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
-
18:20 - MLX Swift
// Swift
import MLX

// Make an array
let a = MLXArray([1, 2, 3])

// Make another array
let b = MLXArray([1, 2, 3])

// Do an operation
let c = a + b

// Access information about the array
let shape = c.shape
let dtype = c.dtype

// Print results
print("a: \(a)")
print("b: \(b)")
print("c = a + b: \(c)")
print("shape: \(shape)")
print("dtype: \(dtype)")
-
- 0:00 - Introduction
MLX is an open-source array framework purpose-built for Apple silicon. It enables efficient machine learning tasks and allows you to run large language models directly on device using Python and Swift.
- 1:15 - MLX overview
This high-performance machine learning framework is optimized for Apple silicon, enabling fast numerical computing and machine learning tasks on CPU and GPU. It has a NumPy-like core API, and the higher-level API is similar to PyTorch and JAX. Use it in applications like LM Studio for on-device text generation with large language models. MLX has APIs in Python, Swift, C++, and C. MLX is open-source under MIT license, and has an active community on Hugging Face.
- 4:21 - Key features
MLX is incredibly efficient because it’s tailored for Apple silicon, leveraging its unified memory architecture. Sharing memory between the CPU and GPU eliminates the need for data copying; operations simply specify the desired device. Instead of immediately executing a computation, MLX builds computation graphs that are executed only when a result is needed. With function transformations, MLX can take functions as input and return new functions, facilitating automatic differentiation and other optimizations. MLX includes higher-level packages for building and training neural networks, as well as common machine learning operations. These packages are modular and designed to be similar to popular frameworks like PyTorch, making it easy for developers to switch.
- 10:15 - Accelerating MLX
Compiling functions with 'mx.compile' fuses multiple GPU kernel launches into a single kernel, saving memory bandwidth and reducing execution overhead. For more complex operations, the 'mx.fast' sub-package provides highly tuned, specialized implementations of common machine learning operations, such as RMS norm and attention mechanisms. MLX enables quantization for faster, more memory-efficient inference, reducing precision without overly impacting quality. Large-scale computations can use the 'mx.distributed' sub-package to distribute work across multiple machines.
- 17:30 - MLX Swift
MLX also offers a Swift API with the same efficiency benefits, for seamless development across Apple platforms in Xcode. Visit the MLX website or download the sample repos to get started and learn more.