ALGM: Adaptive Local-then-Global Token Merging for Efficient Semantic Segmentation with Plain Vision Transformers
paper|code
Background & Motivation
Tokens with a high cosine similarity can be merged without degrading segmentation quality.
- CTS shows that local token sharing in the early network stages improves efficiency without hurting segmentation quality, but it requires a pre-processing network. Our first goal is therefore to merge redundant tokens in the shallow layers without any pre-processing, while preserving segmentation quality.
- Token merging methods such as ToMe show that gradually merging redundant tokens over the entire image greatly improves efficiency, but merging at a global scope harms segmentation quality. Our second goal is therefore to apply global token merging to further improve efficiency without harming segmentation quality.
Challenge
How to design a new method that merges local tokens early like CTS and merges tokens efficiently at a global scope like ToMe, without an extra pre-processing network and without hurting segmentation quality.
Sticking with cosine similarity as the criterion, we find that as the network gets deeper:
- In the early layers, it is sufficient to tell different objects apart locally.
- In the later layers, it distinguishes different objects more clearly at a global scope.
Method
Based on these findings, the Adaptive Local-then-Global Merging (ALGM) module is proposed, which integrates two token-merging stages. In the first network layer, ALGM applies a local merging strategy; in an intermediate layer, it applies a global merging mechanism to reduce global token redundancy. Moreover, instead of presetting the number of tokens to merge, the number of merged tokens is decided dynamically based on the semantic complexity of the image content.
Token similarity analysis
The question is under what conditions, and at which stage, cosine similarity is an effective metric for identifying tokens that represent the same class, and hence when it is suitable for local and global merging.
Tokens are extracted from a ViT-S trained on the ADE20K training set, and their similarities are compared across layers.
(1) First, local similarity is analyzed within k×k windows in the first Transformer layer. As shown in Fig. 2a, the smaller the window size k, the more accurately cosine similarity reflects whether tokens belong to the same class. Therefore, in the first layer, tokens with a high cosine similarity within small local windows can very likely be merged without degrading segmentation quality.
(2) Global similarity is analyzed by computing the cosine similarity between inter-class and intra-class tokens over the whole image, for all Transformer layers. As shown in Fig. 2b, global similarity in the early layers does not accurately reflect class correspondence, so it should not be used to identify tokens to merge. In the deeper parts of the network, however, cosine similarity becomes a much better criterion for identifying tokens that can be merged globally without affecting segmentation quality.
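A minimal sketch of how such a similarity analysis could be run, assuming per-layer tokens of shape (N, D) and a patch-level class-label map of shape (N,); the function name and inputs are illustrative, not the paper's code:

```python
import torch
import torch.nn.functional as F

def intra_inter_class_similarity(tokens: torch.Tensor, labels: torch.Tensor):
    """tokens: (N, D) tokens of one layer; labels: (N,) patch-level class ids.
    Returns the mean intra-class and inter-class cosine similarity."""
    tokens = F.normalize(tokens, dim=-1)
    sims = tokens @ tokens.T                                   # (N, N) pairwise cosine similarity
    same = labels[:, None] == labels[None, :]                  # True where both tokens share a class
    off_diag = ~torch.eye(len(labels), dtype=torch.bool, device=tokens.device)
    intra = sims[same & off_diag].mean()                       # same class, excluding self-pairs
    inter = sims[~same].mean()                                 # different classes
    return intra.item(), inter.item()
```

Running this per layer (and, for the local case, per k×k window) gives curves like Fig. 2: the gap between intra- and inter-class similarity is what makes cosine similarity usable as a merging criterion.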
Adaptive Local-then-Global Merging (ALGM)
Based on the observation that (a) local token similarity in early layers and (b) global token similarity in intermediate layers are reliable indicators of token mergeability, Adaptive Local-then-Global Merging (ALGM) is proposed. Local merging is first performed in the first layer by a Conditional Local Average Pooling (CLAP) module. In an intermediate layer, global merging is performed by a Global Bipartite Merging (GBM) module based on the BSM (bipartite soft matching) algorithm. The whole pipeline ends with a token unmerging module that restores the original token resolution.
Local token merging.
If a token is highly similar to its neighbors within a small window, they are merged. This is implemented by the CLAP module, which is placed between the MHSA and MLP blocks of the first layer (L1).
Step 1.
It takes the tokens T'1 output by the MHSA of the first layer (L1) and rearranges them into a spatial grid T'G1. It then defines windows of size k×k and groups the tokens inside each window into sets W.
Step 2.
The cosine similarities between all tokens within a window are computed and averaged to obtain $\mu_w$. Then, under the assumption that similarity indicates mergeability, the CLAP module only merges windows whose average similarity $\mu_w$ exceeds a threshold $\tau$.
Step 3.
All tokens within a selected window $w$ are merged into a single token by taking their average. The original indices of the merged tokens are also stored for the later unmerging step. Finally, the newly merged tokens and the unmerged tokens are concatenated to form the output, whose size is less than or equal to the original number of tokens.
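A condensed sketch of these three CLAP steps (single image, no class token, window size k; a simplified illustration, not the full implementation listed further below):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def clap_sketch(tokens: torch.Tensor, H: int, W: int, k: int = 2, tau: float = 0.9):
    """tokens: (H*W, D). Merge every k x k window whose mean pairwise cosine similarity > tau."""
    # Step 1: rearrange the token sequence into non-overlapping k x k windows
    win = rearrange(tokens, "(gh k1 gw k2) d -> (gh gw) (k1 k2) d",
                    gh=H // k, gw=W // k, k1=k, k2=k)
    # Step 2: mean pairwise cosine similarity inside each window (self-similarity excluded)
    win_n = F.normalize(win, dim=-1)
    sims = win_n @ win_n.transpose(1, 2)                  # (num_windows, k*k, k*k)
    n = k * k
    mu_w = (sims.sum(dim=(1, 2)) - n) / (n * (n - 1))
    # Step 3: merge windows above the threshold by averaging; keep the others untouched
    merged = win[mu_w > tau].mean(dim=1)                  # (num_merged, D)
    kept = win[mu_w <= tau].reshape(-1, win.shape[-1])    # unmerged tokens, flattened back
    return torch.cat([merged, kept], dim=0)
```

In the real CLAP module the original token indices of each merged window are stored as well, so that the unmerging step can later restore the full grid.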
Global token merging.
Step 1.
Token grouping and graph construction: the tokens are split into two sets and a bipartite graph is built between them.
Step 2.
Best-match search: every token is matched to the single most similar token in the other set as its merging candidate.
Step 3.
Similarity thresholding: only token pairs that are sufficiently similar are allowed to enter the final merging stage.
Step 4.
All edges that survive the previous filtering steps have their token pairs merged, and the corresponding indices are stored. The unmerged tokens and the updated merged tokens are then concatenated into a new, smaller token set that is passed to the next layer.
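A condensed sketch of these four GBM steps (single image, no class token). For clarity, each selected pair is merged one-to-one here, whereas the `turbo_matching` implementation below reduces multiple sources into one destination with `scatter_reduce` and handles batching and index bookkeeping:

```python
import torch
import torch.nn.functional as F

def gbm_sketch(tokens: torch.Tensor, layer_idx: int):
    """tokens: (N, D). One global bipartite merging step with an adaptive threshold."""
    x = F.normalize(tokens, dim=-1)
    # Step 1: split into two sets A (even) and B (odd) and build the bipartite similarity graph
    a, b = x[::2], x[1::2]
    scores = a @ b.T                                        # (|A|, |B|)
    # Step 2: every token in A keeps only its single best match in B
    node_max, node_idx = scores.max(dim=-1)
    # Step 3: adaptive threshold (mean + std / layer_idx, as in the adaptive scheme)
    tau = node_max.mean() + node_max.std() / layer_idx
    keep_edge = node_max > tau
    # Step 4: merge the surviving pairs by averaging, keep everything else unchanged
    merged = (tokens[::2][keep_edge] + tokens[1::2][node_idx[keep_edge]]) / 2
    unmerged_a = tokens[::2][~keep_edge]
    used_b = torch.zeros(len(b), dtype=torch.bool, device=tokens.device)
    used_b[node_idx[keep_edge]] = True                      # B-tokens absorbed by a merge
    unmerged_b = tokens[1::2][~used_b]
    return torch.cat([merged, unmerged_a, unmerged_b], dim=0)
```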
Token unmerging.
Using the indices recorded during merging, the merged tokens are "copy-pasted" back to their original positions, restoring a feature representation with the same resolution as the input.
When this step is executed depends on the downstream decoder:
- If the decoder is a Transformer (order-agnostic), decode first and unmerge afterwards, which is more efficient.
- If the decoder is a CNN (which requires a regular grid), unmerge first and then decode to satisfy its input layout.
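The listings below only track merge provenance (`merge_source`); a minimal sketch of the unmerging step itself, assuming a binary `source` matrix of shape (B, N_merged, N_orig) as produced by `merge_source` (an illustration, not the repository's exact decoder hook):

```python
import torch

def unmerge(x_merged: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """x_merged: (B, N_merged, D) tokens after merging.
    source:   (B, N_merged, N_orig) binary matrix, source[b, i, j] = 1 iff
              original token j was merged into token i (from merge_source).
    Returns (B, N_orig, D): every merged token is copied back to all of its
    original positions ("copy-paste" unmerging)."""
    return source.transpose(1, 2) @ x_merged
```

After this step the tokens can be reshaped back into the H×W grid that a CNN decoder expects.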
Adaptive token merging.
Before training, the base segmentation model to which ALGM will be applied is run on the training set. In each layer $L_l$, the tokens after the MHSA block are extracted, the cosine similarity between all token pairs is computed, and the mean similarity $\mu_{sim}$ and standard deviation $\sigma_{sim}$ over the entire training set are calculated. Based on these statistics, the threshold is set to $\tau = \mu_{sim} + \sigma_{sim}$. With this threshold, the numbers of remaining tokens after the CLAP and GBM modules, N' and N'', vary from image to image. During training, to allow images and tokens to be processed in batches, the maximum numbers of remaining tokens N' and N'' are determined per batch and then applied to all images in that batch.
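A rough sketch of how $\tau$ could be calibrated from training-set statistics. `model.get_tokens` is a hypothetical hook returning the post-MHSA tokens of a given layer, and the loader is assumed to yield (image, label) pairs:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def calibrate_threshold(model, train_loader, layer_idx=1, device="cuda"):
    """Estimate tau = mu_sim + sigma_sim over the training set."""
    means, stds = [], []
    for imgs, _ in train_loader:
        tokens = model.get_tokens(imgs.to(device), layer_idx)   # (B, N, D), hypothetical hook
        t = F.normalize(tokens, dim=-1)
        sims = t @ t.transpose(1, 2)                            # (B, N, N) pairwise cosine similarity
        n = sims.shape[-1]
        off_diag = ~torch.eye(n, dtype=torch.bool, device=sims.device)
        vals = sims[:, off_diag]                                # exclude self-similarity
        means.append(vals.mean())
        stds.append(vals.std())
    mu_sim = torch.stack(means).mean()
    sigma_sim = torch.stack(stds).mean()
    return (mu_sim + sigma_sim).item()                          # tau = mu_sim + sigma_sim
```

Note that averaging per-batch statistics only approximates the dataset-level mean and standard deviation, which is sufficient for setting a threshold.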
Algorithm implementation
local_merge.py
import math
from typing import Callable, Tuple
import torch
import torch.nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
def conditional_pooling(
    feat: torch.Tensor,
    threshold: float,
    window_size: Tuple[int, int],
) -> Callable:

    with torch.no_grad():
        ws_h, ws_w = int(window_size[0]), int(window_size[1])
        stride_h, stride_w = ws_h, ws_w
        num_token_window = stride_h * stride_w

        x_cls, feat = feat[:, :1, :], feat[:, 1:, :]
        B, N, D = feat.size()
        base_grid_H = int(math.sqrt(N))
        base_grid_W = base_grid_H
        assert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0

        feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)
        feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H // ws_h, gw=base_grid_W // ws_w)
        b, gh, gw, c, ps_h, ps_w = feat.shape

        # Flatten each k x k window for pairwise operations
        tensor_flattened = feat.reshape(b, gh, gw, c, -1)

        # Expand dims for pairwise operations
        tensor_1 = tensor_flattened.unsqueeze(-1)
        tensor_2 = tensor_flattened.unsqueeze(-2)

        # Compute cosine similarities between all tokens of a window
        sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)

        # Exclude the self-similarity (i.e., similarity with oneself will be 1)
        sims_mask = 1 - torch.eye(ps_h * ps_w).to(sims.device)
        sims = sims * sims_mask

        # Average similarities (excluding the self-similarity)
        similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))
        similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')

        # --- adaptive section: count windows whose mean similarity exceeds the threshold;
        # take the batch minimum so every image in the batch keeps the same token count ---
        n_B, n_H = similarity_map.shape
        node_mean = torch.tensor(threshold).to(sims.device)
        node_mean = node_mean.repeat(1, n_H)
        r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
        # -------------#

        # Get the top-r most similar super patches (windows)
        _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)

        # --- create the mergeable and unmergeable super patches ---
        tensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)

        # Repeat the index grid for every image in the batch
        tensor = tensor.unsqueeze(0).repeat(B, 1, 1)

        # Apply unfold operation on last two dimensions to create the sliding windows
        windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)

        # Reshape the tensor to the desired shape
        windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)

        # Use torch.gather to collect the token indices of the selected windows
        gathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))

        # Create a mask for all indices, for each batch
        mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)

        # Create a tensor that matches the shape of indices and fill it with False
        mask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)

        # Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all b
        mask.scatter_(1, sim_super_patch_idxs, mask_values)

        # Get the remaining (unmerged) token indices
        remaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)
        unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
        dim_index = num_token_window - 1
        src_idx = gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
        dst_idx = gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
        merge_idx = torch.arange(src_idx.shape[1] // dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # src_idx, dst_idx, unm_idx, merge_idx, r and num_token_window are captured from the enclosing scope
        x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
        n, t1, c = x_feat.shape
        src = x_feat.gather(dim=-2, index=src_idx.expand(n, r * dim_index, c))
        dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
        unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r * num_token_window), c))
        # Reduce all non-destination tokens of a selected window into its destination token
        dst = dst.scatter_reduce(-2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode)
        x = torch.cat([dst, unm], dim=1)
        x = torch.cat((x_cls, x), dim=1)
        return x

    return merge
def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the merge function, weighting each token by its size (the number of patches it represents)."""
    if size is None:
        size = torch.ones_like(x[..., 0, None])
    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")
    x = x / size
    return x, size


def merge_source(
    merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:
    """Track which original tokens each merged token contains (needed later for unmerging)."""
    if source is None:
        n, t, _ = x.shape
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)
    source = merge(source, mode="amax")
    return source
global_merge.py
import math
from typing import Callable, Tuple
import torch
def do_nothing(x, mode=None):
    return x


def turbo_matching(
    metric: torch.Tensor,
    layer_idx: int,
    source: torch.Tensor,
    class_token: bool = False,
    distill_token: bool = False,
) -> Callable:
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # At most half of the (unprotected) tokens can be merged in one bipartite matching step
    t = metric.shape[1]
    r = (t - protected) // 2
    if r <= 0:
        return do_nothing

    with torch.no_grad():
        B, m_t, um_t = source.shape
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # ------------------ start adaptive section ------------------
        # Merge only edges whose similarity exceeds mean + std / layer_idx;
        # take the batch minimum so every image keeps the same token count
        i = layer_idx
        n_B, n_H = node_max.shape
        node_mean = torch.add(node_max[:, 1:].mean(dim=1).mean(), node_max[:, 1:].std(dim=1).mean() / i)
        node_mean = node_mean.repeat(1, n_H)
        r = torch.ge(node_max, node_mean).sum(dim=1).min()
        # ------------------ end adaptive section ------------------

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    return merge
Finally, ALGM is applied between the attention and MLP blocks of the selected Transformer layers. `Block` below is the standard ViT block (e.g. `timm.models.vision_transformer.Block`), and `_turbo_info` is a dict attached to each block that carries the merging configuration and running state (threshold, window size, token sizes, source tracking, selected layers).
class TurboBlock(Block):
    """
    Modifications:
     - Apply ALGM between the attention and mlp blocks
    """

    def _drop_path1(self, x):
        return self.drop_path1(x) if hasattr(self, "drop_path1") else self.drop_path(x)

    def _drop_path2(self, x):
        return self.drop_path2(x) if hasattr(self, "drop_path2") else self.drop_path(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With proportional attention, the current token sizes are passed to the attention block
        attn_size = self._turbo_info["size"] if self._turbo_info["prop_attn"] else None
        x_attn, metric = self.attn(self.norm1(x), attn_size)
        x = x + self._drop_path1(x_attn)

        layer_idx = self._turbo_info["selected_layers"].pop(0)
        if self._turbo_info["source"] is None:  # first selected layer: local merging (CLAP)
            merge = conditional_pooling(
                x,
                self._turbo_info["threshold"],
                self._turbo_info["window_size"],
            )
            if self._turbo_info["trace_source"]:
                self._turbo_info["source"] = merge_source(
                    merge, x, self._turbo_info["source"]
                )
            x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])
        else:  # later selected layers: global bipartite merging (GBM)
            merge = turbo_matching(
                x,
                layer_idx,
                self._turbo_info["source"],
                self._turbo_info["class_token"],
                self._turbo_info["distill_token"],
            )
            if self._turbo_info["trace_source"]:
                self._turbo_info["source"] = merge_source(
                    merge, x, self._turbo_info["source"]
                )
            x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])

        x = x + self._drop_path2(self.mlp(self.norm2(x)))
        return x
Experimental results
Inspire
- Is the local partition-and-merge strategy also effective for low-level, pixel-level tasks, e.g. as a lower-complexity replacement for window attention?