本文翻译整理自:https://github.com/google/flax
文章目录
一、关于 Flax
Flax NNX 是2024年发布的全新简化版Flax API,旨在简化在JAX中创建、检查、调试和分析神经网络的过程。
它通过原生支持Python引用语义来实现这一目标,允许用户使用常规Python对象表达模型,支持引用共享和可变性。
Flax NNX 由Flax Linen API演进而来,后者是2020年由Google Brain工程师和研究人员与JAX团队紧密合作发布的。
您可以在Flax专属文档站点了解更多关于Flax NNX的信息,特别推荐以下内容:
注:Flax Linen的文档有独立站点。
Flax团队的使命是服务于不断增长的JAX神经网络研究生态系统——包括Alphabet内部和更广泛的社区,并探索JAX的闪光用例。我们几乎所有的协调和规划工作都在GitHub上进行,包括讨论即将进行的设计变更。欢迎您在我们的讨论区、问题区和拉取请求线程中提供反馈。
您可以在Flax GitHub讨论区提出功能请求、分享工作内容、报告问题或提问。
我们期待改进Flax,但不预期核心API会有重大破坏性变更。我们会尽可能使用更新日志和弃用警告。
如需直接联系我们,请发邮件至flax-dev@google.com。
相关链接资源
- github : https://github.com/google/flax
- 官网:https://flax.readthedocs.io/
- 官方文档:https://flax.readthedocs.io/
- Demo/在线试用:https://flax.readthedocs.io/en/latest/mnist_tutorial.html
- CodeDEV : https://codecov.io/gh/google/flax
- Community : https://github.com/google/flax/discussions
- Blog : https://flax.readthedocs.io/en/latest/guides/index.html
- FAQ : https://github.com/google/flax/issues
- License : Apache License 2.0
关键功能特性
神经网络API (
flax.nnx
):包含Linear
、Conv
、BatchNorm
、LayerNorm
、GroupNorm
、Attention (MultiHeadAttention
)、LSTMCell
、GRUCell
、Dropout
实用工具和模式:复制训练、序列化和检查点、指标、设备预取
二、安装
Flax基于JAX,请先查看JAX在CPU、GPU和TPU上的安装说明。
需要Python 3.8或更高版本。从PyPi安装Flax:
pip install flax
升级到最新版Flax:
pip install --upgrade git+https://github.com/google/flax.git
安装额外依赖(如matplotlib
):
pip install "flax[all]"
三、Flax代码示例
我们提供三个使用Flax API的示例:简单多层感知机、CNN和自动编码器。
要了解Module
抽象,请查阅我们的文档和Module抽象介绍。更多最佳实践示例,请参考我们的指南和开发者笔记。
多层感知机示例:
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
CNN示例:
class CNN(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
self.linear2 = nnx.Linear(256, 10, rngs=rngs)
def __call__(self, x):
x = self.avg_pool(nnx.relu(self.conv1(x)))
x = self.avg_pool(nnx.relu(self.conv2(x)))
x = x.reshape(x.shape[0], -1) # flatten
x = nnx.relu(self.linear1(x))
x = self.linear2(x)
return x
自动编码器示例:
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)
class AutoEncoder(nnx.Module):
def __init__(self, rngs):
self.encoder = Encoder(rngs)
self.decoder = Decoder(rngs)
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
伊织 xAI 2025-04-27(日)