«

基于JAX-Privacy的大规模差分隐私机器学习

qimuai 发布于 阅读:50 一手编译


基于JAX-Privacy的大规模差分隐私机器学习

内容来源:https://research.google/blog/differentially-private-machine-learning-at-scale-with-jax-privacy/

内容总结:

近日,谷歌深度思维与谷歌研究院联合发布了JAX-Privacy 1.0版本。这款基于高性能计算库JAX的差分隐私机器学习工具包,旨在帮助开发者在保障数据隐私的前提下,高效训练大规模人工智能模型。

随着人工智能技术在个性化推荐、科学研究等领域的深入应用,高质量数据集已成为提升模型性能的关键。然而,如何在充分利用数据的同时保护用户隐私,始终是行业面临的重大挑战。差分隐私技术作为隐私保护的黄金标准,能够通过数学方法确保模型输出不受单个数据点影响,但将其应用于大规模机器学习时,常面临计算复杂度和实现效率的难题。

JAX-Privacy的诞生正是为了突破这一瓶颈。该工具包基于谷歌2020年推出的高性能数值计算库JAX构建,充分利用其自动微分、即时编译和多加速器无缝扩展等特性,为差分隐私训练提供三大核心能力:一是提供梯度裁剪、噪声注入等基础隐私保护组件的标准化实现;二是集成矩阵分解等前沿隐私算法,支持相关噪声注入等优化技术;三是通过微批处理和填充工具,实现大规模可变批量数据的高效处理。

值得关注的是,该工具包已成功应用于VaultGemma大模型的训练——这是目前全球性能最强的差分隐私大语言模型。此次开源版本还提供了完整的实践案例,开发者仅需数行代码即可通过Keras框架对Gemma系列模型进行隐私保护下的微调,适用于对话摘要、数据生成等实际场景。

业界专家认为,JAX-Privacy的发布将显著降低隐私保护机器学习的技术门槛,为医疗健康、金融服务等敏感数据应用领域提供可靠的技术支撑。该工具包现已开放GitHub代码库和PIP安装包,预计将推动隐私保护人工智能技术进入新的发展阶段。

中文翻译:

利用JAX-Privacy实现规模化差分隐私机器学习
2025年11月12日
Borja Balle(谷歌DeepMind主任研究员)与Ryan McKenna(谷歌研究院高级研究员)

我们正式发布JAX-Privacy 1.0——这是一个基于高性能计算库JAX的差分隐私机器学习工具库。
[快速链接]

从个性化推荐到科学突破,人工智能模型正在助力改善人类生活并推动产业变革。但AI模型的影响力与准确性往往取决于所用数据的质量。大规模高质量数据集对开发精准且具代表性的AI模型至关重要,然而这些数据的使用方式必须确保个人隐私不受侵犯。

这正是JAX与JAX-Privacy的价值所在。JAX自2020年推出以来,始终是专为大规模机器学习设计的高性能数值计算库。其核心功能——包括自动微分、即时编译及多加速器无缝扩展能力——使其成为高效构建和训练复杂模型的理想平台。如今JAX已成为AI领域科研人员与工程师突破边界的基础设施,其生态系统还包含一系列专业工具库:如简化神经网络架构实现的Flax,以及集成前沿优化器的Optax。

基于JAX构建的JAX-Privacy,是一套用于构建与审计差分隐私模型的稳健工具包。它帮助研发人员快速高效地实现差分隐私(DP)算法,支持在大规模数据集上训练深度学习模型,并提供将隐私训练融入现代分布式工作流的核心工具。JAX-Privacy初版于2022年问世,旨在让外部研究者能复现验证我们在隐私训练方面的突破。如今它已发展为谷歌各研究团队的成果集成枢纽,持续将创新见解转化为DP训练与审计算法。

今天我们自豪地推出JAX-Privacy 1.0。新版本融合最新研究成果并采用模块化重构,让开发者能更轻松地构建融合尖端DP算法与JAX扩展能力的训练流程。

发展历程:为何需要JAX-Privacy
多年来,差分隐私始终是量化隐私泄露风险的黄金标准。DP能保证算法输出结果不会因数据集中单个个体(或样本)的存在与否而产生显著差异。

尽管DP理论体系成熟,但其在大规模机器学习中的实践仍面临挑战。最常用的差分隐私随机梯度下降法(DP-SGD)需要定制化批处理流程、逐样本梯度裁剪以及精确校准的噪声注入。这一过程计算密集,且在现代基础模型的规模下难以高效准确实现。

现有框架虽取得进展,但在扩展性与灵活性上仍有不足。从开创DP新算法到开发精密审计技术,我们始终在拓展隐私机器学习的边界。这要求我们打造一个既能保证正确性与高效性,又能原生支持前沿模型并行化与复杂度的工具库。

JAX的函数式编程范式及其强大转换功能(如实现自动向量化的vmap和支持单程序多数据并行化的shard_map)为此奠定了坚实基础。依托JAX构建的工具库可原生支持并行计算,助力跨多加速器与超级计算机的大规模模型训练。JAX-Privacy正是这一努力的结晶——这个历经实践检验的工具库已支撑谷歌内部生产集成,现在正式向更广社区开放。

JAX-Privacy的核心能力
JAX-Privacy通过精心设计的组件简化DP实现复杂度:

从研究到实践:安全微调大语言模型
JAX-Privacy最令人振奋之处在于其实际应用能力。该库专为支持现代LLM预训练与微调框架设计,典型例证即我们近期运用其构建模块训练的VaultGemma——当前全球性能最强的差分隐私大语言模型。

通过此次开源发布,我们期望开发者能通过流行的Keras框架,用寥寥数行代码即可安全微调大模型。我们特别提供了完整可运行的Gemma系列模型微调示例(该系列是谷歌DeepMind基于Gemini技术构建的开放模型),展示如何将JAX-Privacy应用于对话摘要生成、合成数据生成等任务,证明该库即使在最先进模型上也能实现一流效果。

通过简化DP集成,JAX-Privacy赋能开发者从零构建隐私保护应用——无论是为医疗应用微调聊天机器人,还是开发个性化金融顾问模型。它显著降低了隐私保护机器学习的技术门槛,让负责任的高性能AI技术更触手可及。

未来展望
我们期待JAX-Privacy为研究社区注入新活力。本次发布凝聚团队多年专注耕耘,是对隐私保护机器学习领域的重要献礼。我们相信这些工具将催生惠及所有人的创新研究浪潮。

我们将持续维护与发展该工具库,吸纳最新研究成果并响应社区需求。期待见证您运用JAX-Privacy创造的成果。立即访问GitHub代码库或安装PIP软件包,开启您的隐私保护机器学习之旅。

致谢
JAX-Privacy凝聚了以下人员的贡献:Leonard Berrada, Robert Stanforth, Brendan McMahan, Christopher A. Choquette-Choo, Galen Andrew, Mikhail Pravilov, Sahra Ghalebikesabi, Aneesh Pappu, Michael Reneer, Jamie Hayes, Vadym Doroshenko, Keith Rush, Dj Dvijotham, Zachary Charles, Peter Kairouz, Soham De, Samuel L. Smith, Judy Hanwen Shen。

英文来源:

Differentially private machine learning at scale with JAX-Privacy
November 12, 2025
Borja Balle, Staff Research Scientist, Google DeepMind, and Ryan McKenna, Senior Research Scientist, Google Research
We announce the release of JAX-Privacy 1.0, a library for differentially private machine learning on the high-performance computing library, JAX.
Quick links
From personalized recommendations to scientific advances, AI models are helping to improve lives and transform industries. But the impact and accuracy of these AI models is often determined by the quality of data they use. Large, high-quality datasets are crucial for developing accurate and representative AI models, however, they must be used in ways that preserve individual privacy.
That’s where JAX and JAX-Privacy come in. Introduced in 2020, JAX is a high-performance numerical computing library designed for large-scale machine learning (ML). Its core features — including automatic differentiation, just-in-time compilation, and seamless scaling across multiple accelerators — make it an ideal platform for building and training complex models efficiently. JAX has become a cornerstone for researchers and engineers pushing the boundaries of AI. Its surrounding ecosystem includes a robust set of domain-specific libraries, including Flax, which simplifies the implementation of neural network architectures, and Optax, which implements state-of-the-art optimizers.
Built on JAX, JAX-Privacy is a robust toolkit for building and auditing differentially private models. It enables researchers and developers to quickly and efficiently implement differentially private (DP) algorithms for training deep learning models on large datasets, and provides the core tools needed to integrate private training into modern distributed training workflows. The original version of JAX-Privacy was introduced in 2022 to enable external researchers to reproduce and validate some of our advances on private training. It has since evolved into a hub where research teams across Google integrate their novel research insights into DP training and auditing algorithms.
Today, we are proud to announce the release of JAX-Privacy 1.0. Integrating our latest research advances and re-designed for modularity, this new version makes it easier than ever for researchers and developers to build DP training pipelines that combine state-of-the-art DP algorithms with the scalability provided by JAX.
How we got here: The need for JAX-Privacy
For years, researchers have turned to DP as the gold standard for quantifying and bounding privacy leakage. DP guarantees that the output of an algorithm is nearly the same whether or not a single individual (or example) is included in the dataset.
While the theory of DP is well-established, its practical implementation in large-scale ML can be a challenge. The most common approach, differentially private stochastic gradient descent (DP-SGD), requires customized batching procedures, per-example gradient clipping, and the addition of carefully calibrated noise. This process is computationally intensive and can be difficult to implement correctly and efficiently, especially at the scale of modern foundation models.
Existing frameworks have made strides, but they often fall short in scalability or flexibility. Our work has consistently pushed the boundaries of private ML, from pioneering new DP algorithms to developing sophisticated auditing techniques. We needed a tool that could keep pace with our research — a library that was not only correct and efficient but also designed from the ground up to handle the parallelism and complexity of state-of-the-art models.
JAX's functional paradigm and powerful transformations, like vmap
(for automatic vectorization) and shard_map
(for single-program multiple-data parallelization), provided a strong foundation. By building on JAX, we could create a library that was parallelism-ready out-of-the-box, supporting the training of large-scale models across multiple accelerators and supercomputers. JAX-Privacy is the culmination of this effort, a time-tested library that has powered internal production integrations and is now being shared with the broader community.
What JAX-Privacy delivers
JAX-Privacy simplifies the complexities of DP by providing a suite of carefully engineered components:

谷歌研究进展

文章目录


    扫描二维码,在手机上阅读