对比pytorch、tensorflow、jax、theano,我发现都在关注两大问题-ag真人官方网

对比pytorch、tensorflow、jax、theano,我发现都在关注两大问题-ag真人官方网

来源:csdn博客 | 2022-12-21 15:57:02 |

作者|王益


(资料图)

oneflow社区编译

翻译|杨婷

最近,我在处理 pytorch 分布式和 torchrec 相关的工作,为此,我开始学习 pytorch 2.0。在业余时间,我也在跟着alpa作者学习jax和xla。如今回顾这些技术,我发现它们的关注点似乎都是如下两个问题:

包含自动求导和并行在内的函数转换,例如 vmap, pmap 和 pjit 等;

异构计算,cpu 负责控制流,gpu/tpu 负责张量计算和集合通信。

本文档中的所有例子都支持在 colab 中运行:

theano/aesara

https://colab.research.google.com/drive/1eg7c5wmnokhxgxq46pna30dxucklqupz

tensorflow 1.x

https://colab.research.google.com/drive/1jc0epg2aaxbihevtozm_33mmhc70rzqz?usp=sharing

tensorflow 2.x

https://colab.research.google.com/drive/1pbftzj9e2_fyiiuoztpexmvlfky_g2nv

pytorch 1.x

https://colab.research.google.com/drive/1v4henl-ij-c6vt5h9w1nc2te85d8vdjk

jax

https://colab.research.google.com/drive/1plfijlizattibd3tbjiebsgpxvq9lvlg

functorch/pytorch 2.x

https://colab.research.google.com/drive/1o-yj-5g1v084rdairw2pqfajog7ty951

1

函数转换

“函数转换”意为将一个程序转变成另一个程序,最常见的例子是自动求导(autograd)。自动求导采用用户编写的前向过程并创建后向过程,对于用户来说,编写自动求导通常都太过复杂。函数转换的主要难点在于:在编写函数转换算法时以何种方式表示输入和输出过程。

theano:显式地构建 ir

theano是最早的深度学习工具之一,也就是如今为人们所熟知的aesara项目。theano有一个允许用户在内存中将ir构建为数据结构的api,因此theano可实现自动求导,并将结果输出为 python 函数。  

import aesarafrom aesara import tensor as ata = at.dscalar("a") # define placeholders, which have no values.b = at.dscalar("b")c = a * b # c now contains the ir of an expression.ttdc = aesara.grad(c, a) # convert the ir in c into another one, dcf_dc = aesara.function([a, b], dc) # convert the ir into a python function,assert f_dc(1.5, 2.5) == 2.5 # so we can call it.

tensorflow 1.x:用于运行 ir 的虚拟机

tensorflow 1.x明确保留了构建ir的想法。若在tensorflow中运行上述示例,结果不会有什么差别;但倘若在tensorflow 1.x中来运行,最大的差别在于:我们不会将后向 ir 转换为 python 函数,并使用 python 解释器来运行。相反,我们会在tensorflow runtime中来运行。  

import tensorflow.compat.v1 as tf # tensorflow 1.x apiimport numpy as nptf.disable_eager_execution()a = tf.placeholder(tf.float32, shape=())b = tf.placeholder(tf.float32, shape=())c = a * bdc = tf.gradients(c, [a], stop_gradients=[a, b])with tf.compat.v1.session() as sess: # tensorflow has a runtime to execute the ir, x = np.single(2) # so, no converting it into python code. y = np.single(3) print(sess.run(dc, feed_dict={a:x, b:y}))

pytorch 1.x:没有前向ir

pytorch不会像theano或tensorflow那样将前向传播转换为ir。反之,pytorch 使用 python 解释器来运行前向传播。这样做的弊端在于会在运行期间生成表示后向传播的 ir,我们称之为eager模式(动态图模式)。  

import torcha = torch.tensor(1.0, requires_grad=true) # these are not placeholders, but values.b = torch.tensor(2.0)c = a * b # evaluates c and derives the ir of the backward in c.grad_fn_.c.backward() # executes c.grad_fn_.print(c.grad)

tensorflow 2.x: 梯度带

tensorflow 2.x增加了一个像pytorch api的eager模式api。此 api 追踪前向传播如何运行名为梯度带(gradienttape)的 ir 。tensorflow 2.x可以从这个跟踪中找出后向传播。

import tensorflow as tfa = tf.variable(1.0) # like pytorch, these are values, not placehodlers. b = tf.variable(2.0)with tf.gradienttape() as tape: c = a * bdcda = tape.gradient(c, a)print(dcda)

jax

jax 不会向用户公开诸如梯度带等方面的低级别细节。简单说来,jax的思维方式为:将输入和输出都用python函数来表示。

import jax a = 2.0b = 3.0jax.grad(jax.lax.mul)(a, b)  # compute c = a * b w.r.t. a. the result is b=3. jax.jit(jax.grad(jax.lax.mul))(a,b)jax.experimental.pjit(jax.grad(jax.lax.mul), device_mesh(ntpus))(a,b)

对于想要自己编写的函数转换的高级用户,他们可以调用make_jaxpr等低级 api 来访问 ir,称为 jaxpr。

jax.make_jaxpr(jax.lax.mul)(2.0, 3.0) # returns the ir representing jax.lax.mul(2,3)jax.make_jaxpr(jax.grad(jax.lax.mul))(2.0, 3.0) # returns the ir of grad(mul)(2,3)

functorch

functorch和jax类似,都是基于pytorch的函数转换。

import torch, functorcha = torch.tensor([2.0])b = torch.tensor([3.0])functorch.grad(torch.dot)(a, b)

jax的make_jaxpr类似于functorch的make_fx。

def f(a, b): return torch.dot(a, b) # have to wrap the builtin function dot into f. # 必须将内置函数dot转换成f. print(functorch.make_fx(f)(a, b).code)print(functorch.make_fx(functorch.grad(f))(a, b).code)

tensorflow 2.x、jax 和 functorch 都为前向传递构建了一个 ir,但 pytorch eager模式没有。ir 不仅可用于自动求导,还可用于其他类型的函数转换。在下列例子中,functorch.compile.aot_function调用了回调函数print_compile_fn两次,分别用于前向和后向传播。

from functorch.compile import aot_functionimport torch.fx as fxdef print_compile_fn(fx_module, args): print(fx_module) return fx_moduleaot_fn = aot_function(torch.dot, print_compile_fn)aot_fn(a, b)

2高阶导数

pytorch

import torchfrom torch import autogradx = torch.tensor(1., requires_grad = true)y = 2*x**3 8first_derivative = autograd.grad(y, x, create_graph=true)print(first_derivative)second_derivative = autograd.grad(first_derivative, x)print(second_derivative)

tensorflow 2.x

import tensorflow as tfx = tf.variable(1.0)with tf.gradienttape() as outer_tape: with tf.gradienttape() as tape: y = 2*x**3 8 dy_dx = tape.gradient(y, x) print(dy_dx) d2y_dx2 = outer_tape.gradient(dy_dx, x) print(d2y_dx2)

jax

def f(a): return 2*a**3 8print(jax.grad(f)(1.0))print(jax.grad(jax.grad(f))(1.0))

3动态控制流

动态控制流(dynamic control flows)有两个层级:在 cpu 上运行的粗粒度级别和在 gpu /tpu 上运行的细粒度级别。本部分主要介绍在 cpu 上运行的粗粒度级别的动态控制流。下面我们将用(if/else)条件语句作为例子检验深度学习工具。

tensorflow 1.x

在 tensorflow 1.x 中,我们需要将条件语句显式构建到 ir 中。此时条件语句是一个特殊的运算符 tf.cond。

def f1(): return tf.multiply(a, 17)def f2(): return tf.add(b, 23)r = tf.cond(tf.less(a, b), f1, f2)with tf.compat.v1.session() as sess: # tensorflow has a runtime to execute the ir, print(sess.run(r, feed_dict={a:x, b:y}))

tensorflow 2.x

tensorflow 2.x 支持使用 tf.cond 和 tf.while_loop 显式构建控制流。此外,实验项目google/tangent中有autograph功能,它可以将python控制流转换为tf.cond或tf.while_loop。此功能利用了 python 解释器支持的函数和函数源代码。例如下面的g函数调用了 python 的标准库将源代码解析为 ast,然后调用 ssa 表单来理解控制流。

def g(x, y): if tf.reduce_any(x < y): return tf.multiply(x, 17) return tf.add(y, 23) converted_g = tf.autograph.to_graph(g)import inspectprint(inspect.getsource(converted_g))

jax

由于部分python语法很复杂,所以通过解析源代码来理解控制流就显得很困难,这就导致autograph经常出错。但如果这种方法很简单,那么python开发者社区也不会在构建python编译器时失败这么多次了。正是由于有这种挑战的存在,必须要明确地将控制流构建到 ir 中。为此,jax 提供了 jax.lax.cond 和 jax.lax.for_loop函数。

jax.lax.cond(a < b, lambda : a*17, lambda: b 23)

考虑到这一点,你可能会觉得我们可以使用递归算法。但是下面用于计算阶乘的递归无法用jax跟踪。

def factorial(r, x): return jax.lax.cond(x <= 1.0, lambda: r, lambda: factorial(r*x, x-1))factorial(1.0, 3.0)

可能你还想调用factorial来计算 3!=6。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。

pytorch

pytorch最初是python-native。正如前文所说,由于多功能调度机制,grad 和 vamp 的函数转换都是即时的。值得注意的是:

相比theano 和 tensorflow构建ir后的函数转换,即时函数转换效率更高。

在进行grad和vmap 时,jax也是即时函数转换。然而像pamp和pjit等更复杂的函数转换需要对整个计算过程进行概述,在这个过程中ir是必不可少的。

由于ir在pmap 和 pjit中的必要性,pytorch社区最近添加了torch.condpytorch/pytorch#83154  

4分布式计算

根据执行代码或 ir 的不同方式,在使用 python 解释器或runtime时,有两种分布式计算方法。

python-native

theano和pytorch采用了python-native分布式计算方式。这种分布式训练工作包含多个python解释器进程。这导致出现了以下结果。

打包和运行(pack and run)。由于这些 python 进程在不同的host上运行,因此我们需要打包用户程序和依赖项,并将它们发送到这些host上去运行。一直以来torchx负责了这个打包过程。它支持例如docker和torch.package等各种打包格式,并且可以与各种集群管理器配合使用,如kubernetes和slurm。

单程序多数据(spmd)。由于将用户程序发送到各种host上要依赖于打包,与其他权重较轻的方式(如通过 rpc 发送代码)相比,这种方式不太灵活,因此,我们通常只发送一个程序。当所有这些进程运行同一程序时,这个作业就变成了单程序多数据(spmd)作业。

python-native spmd

下面是一个简单的spmd pytorch程序,我们可以在相同或不同的host上使用进程运行这个程序。在这个过程中,我们只需要调用all_gather。真正的分布式训练程序会调用更高级别的api,例如torch.nn.parallel.distributeddataparallel 和 torchrec.distributedmodelparallel, 然后再调用低级 api,例如 all_gather 和 all_reduce。

import osimport torchfrom torch import distributed as distdef main(): use_gpu = torch.cuda.is_available() local_rank = int(os.environ.get("local_rank", "0")) local_world_size = int(os.environ.get("local_world_size", "0")) device = torch.device(f"cuda:{local_rank}" if use_gpu else "cpu") dist.init_distributed(backend="nccl") lst = torch.tensor([local_rank 100]).to(device) # placeholder rlt_lst = [torch.zeros_like(lst) for _ in range(local_world_size)] dist.all_gather(rlt_lst, lst, async_op=false)    print("after broadcasting:", rlt_lst)

python-native non-spmd

pytorch 不仅限于 spmd 式的分布式训练。它还通过torch.distributed.pipeline.sync.pipe和pippy project提供流水并行,其中流水并行的各个阶段在不同的设备上运行不同的程序。这些阶段常通过 torch.rpc 包来沟通。

分布式运行时机制

分布式 tensorflow 作业由运行 tensorflow runtime 程序的进程组成,而不是由 python 解释器组成。此分布式运行时作业执行 tensorflow graph (ir),它是由执行用户程序的 python 解释器生成。

用户程序可以使用低级api(如 tf.device)去指定作业要运行什么操作、在哪台设备和主机上运行等等。因为api有runtime,所以可以做到这一点。

with tf.device("/job:bar/task:0/device:gpu:2"):    # ops created here have the fully specified device above

与pytorch一样,tensorflow也为分布式训练提供了高级api tf.distributed.strategy,keras和dtensor。

strategy = tf.distribute.mirroredstrategy() \ if tf.config.list_physical_devices("gpu") \           else tf.distribute.get_strategy()with strategy.scope(): model = tf.keras.sequential([tf.keras.layers.dense(1, input_shape=(1,))])model.compile(loss="mse", optimizer="sgd")

分布式运行时极大地方便了训练服务的维护,因为我们不再将用户程序打包到集群上运行。相反,我们打包运行时程序,因为相比用户程序,运行时程序更加统一。

混合理念

jax 支持 python-native 和分布式运行时。

jax 提供例如vmap、pmap 和 pjit的函数转换,这可以将 python 函数转换为分布式程序。

(本文经授权后由oneflow社区编译,译文转载请联系获得授权。原文:https://quip.com/y8qtayv4exrg)

其他人都在看

下载量突破10亿,minio的开源启示录

关于chatgpt的一切;cuda入门之矩阵乘

李白:你的模型权重很不错,可惜被我没收了

单rtx 3090训练yolov5s,时间减少11小时

openai掌门sam altman:ai下一个发展阶段

比快更快,开源stable diffusion刷新作图速度

oneembedding:单卡训练tb级推荐模型不是梦

欢迎star、试用oneflow最新版本:github - oneflow-inc/oneflow: oneflow is a deep learning framework designed to be user-friendly, scalable and efficient.oneflow is a deep learning framework designed to be user-friendly, scalable and efficient. - github - oneflow-inc/oneflow: oneflow is a deep learning framework designed to be user-friendly, scalable and efficient.https://github.com/oneflow-inc/oneflow/

关键词:

ag真人官方网 ag真人官方网的版权所有.

联系网站:920 891 263@qq.com
网站地图