oneflow.expand(input, *sizes)
expand 算子的功能簡(jiǎn)單來(lái)說(shuō)就是,能夠實(shí)現將輸入張量的沿著(zhù)大小為 1 的維度進(jìn)行復制,至于復制多少份由第二個(gè)參數決定(下面將該參數稱(chēng)為 expand_size)。
下面介紹 expand_size 設置的一些約定:
expand_size 維度大小大于等于輸入張量,如果大于輸入維度則相當于輸出會(huì )增加維度
對于輸入張量為 1 的維度, expand_size 對應維度可以設置為大于等于 1 的值
對于輸入張量不為 1 的維度, expand_size 對應維度只能設置為等于輸入或者 -1
新添加的維度只能加在開(kāi)頭且不能設置 -1,新增維度也就相當于將整個(gè)輸入張量進(jìn)行復制
示例1:
input_shape = [4, 3, 1, 2]exand_size = [4, 3, 5, 2] # 下面這些 expand_size 的設置都是合法的# [-1, 3, 5, 2] # [-1, -1, 5, 2] # [-1, -1, 5, -1] # [4, -1, 5, 2]# [4, -1, 5, -1]# [4, 3, 5, -1]out_shape = [4, 3, 5, 2]
示例2:
input_shape = [1, 4, 3, 5]exand_size = [2, 1, 2, 4, 3, 5] # 下面這些 expand_size 的設置都是合法的# [2, 1, 2, -1, 3, 5] # [2, 1, 2, -1, -1, 5] # [2, 1, 2, -1, -1, -1] # [2, 1, 2, 4, -1, 5] # [2, 1, 2, 4, -1, -1] # [2, 1, 2, 4, 3, -1] out_shape = [2, 1, 2, 4, 3, 5]
接下來(lái)介紹 expand 算子單卡視角下的實(shí)現思路,也就是先不用考慮分布式的情況。
從上一節的介紹可知 expand 算子的輸出張量的某個(gè)位置的值就是從輸入張量的某個(gè)位置復制過(guò)來(lái)的,所以問(wèn)題就轉化成了如何把輸出某個(gè)位置的索引映射到輸入對應位置的索引。
在介紹如何計算索引映射之前,首先來(lái)復習一下張量的 stride 屬性這個(gè)概念。對于內存連續的 n 維張量,可以通過(guò)其 stride 屬性,快速定位到該張量任意位置的索引 (x, y, z, k) 對應的一維索引。
舉個(gè)例子:
input_shape = [6, 3, 4, 5]stride = [60, 20, 5, 1] # 下面會(huì )介紹 stide 的計算方法input[x, y, z, k] == input_flatten[x * 60 + y * 20 + z * 5 + k * 1]
stride 每一維度的數值表示該維度索引每增加1,對應到內存上應該移動(dòng)的步長(cháng),stride 每一維的計算公式如下:

示例代碼:
# 最后一維初始化為1stride = [1]# 從后往前生成 stridefor i in range(len(input_shape) - 2, -1, -1): # 在 stride 數組開(kāi)頭插入元素 stride.insert(0, input_stride[0] * input_shape[i + 1])
接著(zhù)來(lái)看該如何計算 expand 算子的輸出索引到輸入索引的映射。
我們知道如果輸入張量某維度是 1,而 expanbd_size 對應的維度大于 1,相當于是將輸入張量會(huì )沿著(zhù)該維度進(jìn)行復制。也就是對于復制的維度來(lái)說(shuō),不管該輸出維度的索引是多少,都對應著(zhù)輸入張量該維度的索引 0。其實(shí)就是通過(guò)修改輸入張量的 stride 參數構造一個(gè)新的 output_stride ,該 output_stride 的計算方法就是:
如果 expand_size 某維度 i 值為 -1,或者與輸入張量的對應維度一致,則 output_stirde[i] = stride[i]
如果 expand_size 某維度 i 值大于 1,而輸入張量對應維度為 1,則 output_stride[i] = 0
對于 expand_size 維度大于輸入張量維度的情況,則對于新添加的維度 i,output_stride[i] = 0
計算 output_stirde 的示例代碼:
output_stride = []diff = len(expand_size) - len(input_shape)for i in range(len(expand_size) - 1, -1, -1): if i >= diff: if expand_size[i] == -1 or expand_size[i] == input_shape[i - diff]: output_stride.insert(0, input_stride[i - diff]) else: assert expand_size[i] >= 1 and input_shape[i - diff] == 1 output_stride.insert(0, 0) else: assert expand_size[i] >= 1 output_stride.insert(0, 0)
舉個(gè)例子:
input_shape = [4, 1, 3, 5]stride = [15, 15, 5, 1]exand_size = [2, 1, 4, 4, 3, 5] output_stride = [0, 0, 15, 0, 5, 1]# 輸出張量意位置的索引 (x, y, z, k, v, w)output[x, y, z, k, v, w] = input_flatten[x * 0 + y * 0 + z * 15 + k * 0 + v * 5 + w * 1]# 反向的計算邏輯input_grad_flatten[x * 0 + y * 0 + z * 15 + k * 0 + v * 5 + w * 1] += output_grad[x, y, z, k, v, w]
前向代碼鏈接:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/expand_kernel.cu#L30
反向代碼鏈接:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/expand_kernel.cu#L43
多卡一致性視角
接下來(lái)介紹 OneFlow 中添加算子與其他框架不一樣的地方。除了要正確實(shí)現單卡視角下的計算邏輯,還需要考慮多卡一致性視角下的邏輯,包括輸出形狀推理的邏輯、sbp 簽名的設置和實(shí)際計算的邏輯。
首先簡(jiǎn)單介紹一致性視角的概念:
OneFlow 提出了一致性視角(consistent view)的概念,用于簡(jiǎn)化分布式訓練。簡(jiǎn)單而言,在 OneFlow 的一致性視角下,集群被抽象為一臺“超級計算設備”。用戶(hù)無(wú)需關(guān)心集群中計算和通信的細節,只需要關(guān)心邏輯上的數據與計算??梢韵駟螜C單卡那樣去思考要和編程,就能進(jìn)行分布式訓練。
然后什么是 sbp:
sbp 是 OneFlow 發(fā)明的概念,描述了在一致性視角下的 數據與集群中真實(shí)的物理設備上的數據的映射關(guān)系。它由 split, broadcast, partial 的首字母組合而成。
split
表示真實(shí)物理設備上的張量,是將一致性視角的張量切分得到的。切分時(shí)需要指定切分的維度,而真實(shí)物理設備上的張量經(jīng)過(guò)拼接之后可以還原得到一致性視角的張量。
broadcast
表示一致性視角下的張量,在所有的真實(shí)物理設備上都有一份完整的復制。
partial
表示一致性視角下的張量與物理設備上的張量的形狀相同,但是對于物理設備上的值,只是一致性視角下張量的一部分。以 partial_sum 為例,如果我們將集群中所有設備的張量按位置相加,才可以還原得到一致性視角的張量。除了 sum 外,min、max 等操作也適用于 partial。
更多詳細內容可以參考:
https://docs.oneflow.org/v0.5.0/parallelism/02_sbp.html
所以在 Oneflow 中開(kāi)發(fā)算子,開(kāi)發(fā)者還需要為算子設置其輸入和輸出支持哪些 sbp 簽名的組合,這也是需要付出的額外學(xué)習成本。
而在一致性視角下,算子的實(shí)現邏輯有可能需要考慮,其在真實(shí)物理設備上的計算與邏輯上的計算(也就是一致性視角)不一致的地方。
比如對于 expand 算子,在真實(shí)物理設備上計算的時(shí)候,就可能需要修改用戶(hù)傳入的邏輯上的 expand_size。主要原因在于 expand 算子的 sbp 簽名支持對輸入進(jìn)行 split
具體代碼鏈接:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/ops/expand_op.cpp#L62
舉個(gè)具體的例子:
logical input_shape = [4, 3, 1, 2]logical stride = [6, 2, 2, 1]logical expand_size = [2, 4, 3, 4, 2]logical output_stride = [0, 6, 2, 0, 1]
假設用戶(hù)設置了輸入張量的 sbp 為 split(3) ,也就是對最后一維度進(jìn)行切分。且設置該邏輯張量放置在兩張卡上,則每張卡上的真實(shí)物理形狀為:
physical input_shape = [4, 3, 1, 1]physical stride = [3, 1, 1, 1]
則對于真實(shí)物理設備上的 expand_size 和 output_stride 都需要做修改:
physical expand_size = [2, 4, 3, 4, 1]physical output_stride = [0, 3, 1, 0, 1]
為什么 expand_size 需要修改呢?
首先在一致性視角下,每個(gè)物理設備上進(jìn)行實(shí)際計算的時(shí)候,實(shí)際上拿到的輸入大小是切分之后的物理形狀。
而對于上面的例子,輸入的在每個(gè)設備上的物理形狀變?yōu)?nbsp;[4, 3, 1, 1],而如果 expand_size 這時(shí)候仍然保持用戶(hù)設置的邏輯大小 [2, 4, 3, 4, 2],則在每個(gè)設備上的輸出大小是 [2, 4, 3, 4, 2],則輸出對應的邏輯形狀則是 [2, 4, 3, 4, 4],則輸出結果最后一維就比原來(lái)多了。
而由于用戶(hù)怎么設置 sbp 是運行時(shí)才能拿到的信息,所以在物理設備上進(jìn)行計算之前,都需要根據實(shí)際的輸入大小,重新計算 expand_size 和 output_stride。
具體代碼鏈接:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/expand_kernel.cu#L129
oneflow.repeat(input, *sizes)
repeat 算子的功能簡(jiǎn)單來(lái)說(shuō)就是,能夠實(shí)現將輸入張量的任意維度進(jìn)行復制,至于復制多少份由第二個(gè)參數決定(下面將該參數稱(chēng)為 repeat_size)。
下面介紹 repeat_size 設置的一些約定:
repeat_size 維度大小大于等于輸入張量,如果大于輸入維度則相當于輸出會(huì )增加維度
repeat_size 任意維度的值需要設置為大于等于 1 ,假設某維度設為 n, 則先當于輸入張量對應維度復制 n-1 份
新添加的維度只能加在開(kāi)頭且不能設置為小于1的值,新增維度也就相當于將整個(gè)輸入張量進(jìn)行復制,假設新增維度設為 n, 則先當于將輸入張量復制 n-1 份
repeat_size 任意維度的值其實(shí)也可以設置為 0,但是這里不考慮這種情況
則輸出張量每一維的大小計算方式如下:
對于非新增的維度:

對于新增的維度:

input_shape = [4, 1, 3, 5]repeat_size = [2, 1, 2, 4, 1, 1] output_shape = [2, 1, 8, 4, 3, 5]
其實(shí)仔細思考一下,可以感覺(jué)到 repeat 算子和 expand 算子其實(shí)是有聯(lián)系的,也就是 repeat 算子是可以通過(guò) expand 算子來(lái)實(shí)現。
舉些例子:
例子1:
input_shape = [5] repeat_size = [3] output_shape = [15]# 等價(jià)與以下操作input_shape = [5] reshaped_input_shape = [1, 5] expand_size = [3, 5] output_shape = [3, 5] reshaped_output_shape = [15]
例子2:
input_shape = [3, 1, 5]repeat_size = [5, 3, 1] output_shape = [15, 3, 5]# 等價(jià)于以下操作input_shape = [3, 1, 5]reshaped_input_shape = [1, 3, 1, 5]expand_size = [5, 3, 3, 5] output_shape = [5, 3, 3, 5]reshaped_output_shape = [15, 3, 5]
例子3:
input_shape = [3, 1, 5]repeat_size = [2, 5, 3, 1] output_shape = [2, 15, 3, 5]# 等價(jià)與以下操作input_shape = [3, 1, 5]reshaped_input_shape = [1, 3, 1, 5]expand_size = [2, 5, 3, 3, 5] output_shape = [2, 5, 3, 3, 5]reshaped_output_shape = [2, 15, 3, 5]
從上面的例子可以知道, repeat 操作可以用 reshape + expand + reshape 來(lái)代替,問(wèn)題就轉化成如何根據 input_shape 和 repeat_size 計算得到輸入的 reshape大小,expand_size 和輸出的 reshape 大小。
計算示例代碼:
input_reshape = [] output_reshape = [] expand_size = [] diff = len(repeat_size) - len(input_shape)for i in range(len(repeat_size) - 1, -1, -1): if i >= diff: if repeat_size[i] > 1: if input_shape[i - diff] > 1: input_reshape.insert(0, input_shape[i - diff]) input_reshape.insert(0, 1) expand_size.insert(0, input_shape[i - diff]) expand_size.insert(0, repeat_size[i]) output_reshape.insert(0, input_shape[i - diff] * repeat_size[i]) else: input_reshape.insert(0, input_shape[i - diff]) expand_size.insert(0, repeat_size[i]) output_reshape.insert(0, repeat_size[i]) else: input_reshape.insert(0, input_shape[i - diff]) expand_size.insert(0, input_shape[i - diff]) output_reshape.insert(0, input_shape[i - diff]) else: # 新增的維度 expand_size.insert(0, repeat_size[i]) output_reshape.insert(0, repeat_size[i]) new_tensor = flow.reshape(input, input_reshape) tmp_tensor = new_tensor.expand(*expand_size) out = flow.reshape(tmp_tensor, output_reshape)
不過(guò)這算是取巧的實(shí)現了 repeat 算子,因為替換成了reshape 和 expand 算子來(lái)實(shí)現,所以也不用考慮 sbp 的問(wèn)題了,不過(guò)后續為了性能還是需要單獨寫(xiě)一個(gè)算子實(shí)現的。
聯(lián)系客服