From c810c32d6246f9f98fad8fed6f9292548799bed3 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 11:55:45 +0000 Subject: [PATCH 01/35] initial commit of sub-quadratic attention source from https://github.com/AminRezaei0x443/memory-efficient-attention. --- .../models/sub_quadratic_attention.py | 180 ++++++++++++++++++ .../utils/attention_slicing_utils.py | 36 ++++ 2 files changed, 216 insertions(+) create mode 100644 src/diffusers/models/sub_quadratic_attention.py create mode 100644 src/diffusers/utils/attention_slicing_utils.py diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py new file mode 100644 index 000000000000..0d6db33508c6 --- /dev/null +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -0,0 +1,180 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (primary author) +# xloem +# calc_fn +# sparse broadcasting for bias, mask, weights +# flattened conditions for clarity +# Hyungon Ryu (device arg fix) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +import torch +from torch.utils.checkpoint import checkpoint +from ..utils.attention_slicing_utils import dynamic_slice, map_pt, scan +import math + + +def _query_chunk_attention(query_idx, query, key, value, + mask, bias, key_chunk_size=4096, + mask_calc_fn=None, + bias_calc_fn=None, + weights_calc_fn=None, + calc_fn_data=None): + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + num_q = query.shape[-3] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / math.sqrt(k_features) + + def summarize_chunk(key_idx, query, key, value, mask, bias): + attn_weights = torch.einsum('...qhd,...khd->...qhk', query, key) + if bias_calc_fn is not None: + bias = bias_calc_fn(query_idx, key_idx, bias, attn_weights, calc_fn_data) + if bias is not None: + bias = torch.einsum('...hqk->...qhk', bias) + attn_weights = attn_weights + bias + if mask_calc_fn is not None: + mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data) + if mask is not None: + big_neg = torch.finfo(attn_weights.dtype).min + big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32) + mask = torch.einsum('...hqk->...qhk', mask) + attn_weights = torch.where(mask, attn_weights, big_neg) + if weights_calc_fn is not None: + attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights) + max_score = torch.einsum('...qhk->...qh', max_score) + return exp_values, exp_weights.sum(dim=-1), max_score + + def chunk_scanner(chunk_idx): + key_chunk = dynamic_slice(key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), + tuple(key.shape[:-3]) + (key_chunk_size, num_heads, k_features)) + value_chunk = dynamic_slice(value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), + tuple(value.shape[:-3]) + (key_chunk_size, num_heads, v_features)) + + if bias is None: + bias_chunk = None + elif bias.shape[-1] == 1: + bias_chunk = bias + elif bias.shape[-1] == num_kv: + bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, 0, chunk_idx), + tuple(bias.shape[:-3]) + (bias.shape[-3], 
bias.shape[-2], key_chunk_size)) + else: + raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') + + if mask is None: + mask_chunk = None + elif mask.shape[-1] == 1: + mask_chunk = mask + elif mask.shape[-1] == num_kv: + mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, 0, chunk_idx), + tuple(mask.shape[:-3]) + (mask.shape[-3], mask.shape[-2], key_chunk_size)) + else: + raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') + + return checkpoint(summarize_chunk, chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) + + chunk_values, chunk_weights, chunk_max = map_pt( + chunk_scanner, xs=torch.arange(0, num_kv, key_chunk_size)) + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + + +def efficient_dot_product_attention(query, key, value, + mask=None, bias=None, + query_chunk_size=1024, + key_chunk_size=4096, + bias_calc_fn=None, + mask_calc_fn=None, + weights_calc_fn=None, + calc_fn_data=None): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Note: query, key, value needn't have any batch dimensions. + Args: + query: queries for calculating attention with shape of + `[batch..., q_length, num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of + `[batch..., kv_length, num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of + `[batch..., kv_length, num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch..., num_heads, q_length, kv_length]`. + This can be used for incorporating padding masks, proximity bias, etc. + mask: mask for the attention weights. This should be broadcastable to the + shape `[batch..., num_heads, q_length, kv_length]`. + Attention weights are masked out if their corresponding mask value + is `False`. + query_chunk_size: int: query chunks size + key_chunk_size: int: key chunks size + bias_calc_fn: a bias calculation callback for each chunk, of form + `(q_offset, k_offset, bias_chunk, attn_weights, calc_fn_data) -> bias`. + This can be used for incorporating causal masks, padding masks, + proximity bias, etc. + mask_calc_fn: a mask calculation callback for each chunk, of form + `(q_offset, k_offset, mask_chunk, attn_weights, calc_fn_data) -> mask`. + This can be used for incorporating causal or other large masks. + Attention weights are masked out if their corresponding mask value + is `False`. + weights_calc_fn: a general attn_weights callback for each chunk, of form + `(q_offset, k_offset, attn_weights, calc_fn_data) -> attn_weights`. + attn_weights has shape of + `[batch..., q_chunk_size, num_heads, k_chunk_size]`. + This can be used to implement complex weights processing in a memory + efficient way. + calc_fn_data: optional pure data to pass to each per-chunk call of + bias_calc_fn, mask_calc_fn, and weights_calc_fn. + weights_calc_data: pure_data to pass with each call to weights_calc_fn + Returns: + Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. 
+ """ + num_q, num_heads, q_features = query.shape[-3:] + num_kv = key.shape[-3] + + def chunk_scanner(chunk_idx, _): + query_chunk = dynamic_slice(query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), + tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) + + if mask is None: + mask_chunk = None + elif mask.shape[-2] == 1: + mask_chunk = mask + elif mask.shape[-2] == num_q: + mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, chunk_idx, 0), + tuple(mask.shape[:-3]) + (mask.shape[-3], min(query_chunk_size, num_q), mask.shape[-1])) + else: + raise TypeError(f'mask.shape[-2] == {mask.shape[-2]} must broadcast with query.shape[-3] == {num_q}') + + if bias is None: + bias_chunk = None + elif bias.shape[-2] == 1: + bias_chunk = bias + elif bias.shape[-2] == num_q: + bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, chunk_idx, 0), + tuple(bias.shape[:-3]) + (bias.shape[-3], min(query_chunk_size, num_q), bias.shape[-1])) + else: + raise TypeError(f'bias.shape[-2] == {bias.shape[-2]} must broadcast with query.shape[-3] == {num_q}') + return (chunk_idx + query_chunk_size, + _query_chunk_attention(chunk_idx, query_chunk, key, value, mask_chunk, bias_chunk, key_chunk_size=key_chunk_size, + bias_calc_fn=bias_calc_fn, mask_calc_fn=mask_calc_fn, + weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data)) + + _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) + rl = [res[i] for i in range(res.shape[0])] + return torch.cat(rl, dim=-3) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py new file mode 100644 index 000000000000..44a7af9b773d --- /dev/null +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -0,0 +1,36 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/utils.py +# license: +# unspecified +# credit: +# Amin Rezaei (primary author) +# Hyungon Ryu (device arg fix) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 +import torch +import numpy as np + + +def dynamic_slice(x, starts, sizes): + # start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) + starts = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] + for i, (start, size) in enumerate(zip(starts, sizes)): + x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device)) + return x + + +def map_pt(f, xs): + t = [f(x) for x in xs] + return tuple(map(torch.stack, zip(*t))) + + +def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, torch.stack(ys) \ No newline at end of file From c9b3b9f591e6a0a698a4a7f5f6db1bbafa982f84 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 13:19:26 +0000 Subject: [PATCH 02/35] invoke efficient_dot_product_attention(). not currently giving correct results. 
--- src/diffusers/models/cross_attention.py | 76 ++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 98173cb8a406..68f834a4747b 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Union +from .sub_quadratic_attention import efficient_dot_product_attention import torch import torch.nn.functional as F -from torch import nn +from torch import nn, Tensor from ..utils.import_utils import is_xformers_available @@ -145,6 +146,20 @@ def set_attention_slice(self, slice_size): processor = CrossAttnProcessor() self.set_processor(processor) + + def set_subquadratic_attention( + self, + query_chunk_size = 1024, + kv_chunk_size: Optional[int] = None, + ): + r""" + Args: + query_chunk_size (`int`, *optional*, defaults to `1024`) + kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key tokens) is used. + """ + processor = SubQuadraticCrossAttnProcessor(query_chunk_size, kv_chunk_size) + + self.set_processor(processor) def set_processor(self, processor: "AttnProcessor"): self.processor = processor @@ -236,6 +251,65 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No return hidden_states +class SubQuadraticCrossAttnProcessor: + query_chunk_size: int + kv_chunk_size: Optional[int] + def __init__( + self, + query_chunk_size = 1024, + kv_chunk_size: Optional[int] = None + ): + r""" + Args: + query_chunk_size (`int`, *optional*, defaults to `1024`) + kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key tokens) is used. + """ + self.query_chunk_size = query_chunk_size + self.kv_chunk_size = kv_chunk_size + + def __call__( + self, + attn: CrossAttention, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor]=None, + attention_mask: Optional[Tensor]=None, + ): + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + assert attention_mask is None, "attention-mask not currently tested for SubQuadraticCrossAttnProcessor." + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + dtype = query.dtype + # TODO: do we still need this given how we delay the division? 
+ if attn.upcast_attention: + query = query.float() + key = key.float() + value = value.float() + + hidden_states = efficient_dot_product_attention( + query, + key, + value, + query_chunk_size=self.query_chunk_size, + key_chunk_size=self.kv_chunk_size, + ) + hidden_states = hidden_states.to(dtype) + + hidden_states = hidden_states.flatten(2) + + out_proj, dropout = attn.to_out + hidden_states = out_proj(hidden_states) + hidden_states = dropout(hidden_states) + + return hidden_states + class CrossAttnAddedKVProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): From 70dc50d5696e3d712f6c89e5f8c640aeff211f44 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 13:19:47 +0000 Subject: [PATCH 03/35] provide a way to skip checkpointing --- src/diffusers/models/cross_attention.py | 1 + .../models/sub_quadratic_attention.py | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 68f834a4747b..cfef098e3bae 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -299,6 +299,7 @@ def __call__( value, query_chunk_size=self.query_chunk_size, key_chunk_size=self.kv_chunk_size, + use_checkpoint=attn.training, ) hidden_states = hidden_states.to(dtype) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 0d6db33508c6..9560b5b4719a 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -13,22 +13,26 @@ # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 +from functools import partial import torch from torch.utils.checkpoint import checkpoint from ..utils.attention_slicing_utils import dynamic_slice, map_pt, scan import math +from typing import Optional def _query_chunk_attention(query_idx, query, key, value, - mask, bias, key_chunk_size=4096, + mask, bias, + key_chunk_size: Optional[int]=None, mask_calc_fn=None, bias_calc_fn=None, weights_calc_fn=None, - calc_fn_data=None): + calc_fn_data=None, + use_checkpoint=True): num_kv, num_heads, k_features = key.shape[-3:] v_features = value.shape[-1] num_q = query.shape[-3] - key_chunk_size = min(key_chunk_size, num_kv) + key_chunk_size = min(key_chunk_size or int(math.sqrt(num_kv)), num_kv) query = query / math.sqrt(k_features) def summarize_chunk(key_idx, query, key, value, mask, bias): @@ -53,6 +57,7 @@ def summarize_chunk(key_idx, query, key, value, mask, bias): exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights) max_score = torch.einsum('...qhk->...qh', max_score) return exp_values, exp_weights.sum(dim=-1), max_score + summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk def chunk_scanner(chunk_idx): key_chunk = dynamic_slice(key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), @@ -80,7 +85,7 @@ def chunk_scanner(chunk_idx): else: raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') - return checkpoint(summarize_chunk, chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) + return summarizer(chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) chunk_values, chunk_weights, chunk_max = map_pt( chunk_scanner, xs=torch.arange(0, num_kv, key_chunk_size)) @@ -98,11 +103,12 @@ def chunk_scanner(chunk_idx): def efficient_dot_product_attention(query, 
key, value, mask=None, bias=None, query_chunk_size=1024, - key_chunk_size=4096, + key_chunk_size: Optional[int] = None, bias_calc_fn=None, mask_calc_fn=None, weights_calc_fn=None, - calc_fn_data=None): + calc_fn_data=None, + use_checkpoint=True): """Computes efficient dot-product attention given query, key, and value. This is efficient version of attention presented in https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. @@ -173,7 +179,7 @@ def chunk_scanner(chunk_idx, _): return (chunk_idx + query_chunk_size, _query_chunk_attention(chunk_idx, query_chunk, key, value, mask_chunk, bias_chunk, key_chunk_size=key_chunk_size, bias_calc_fn=bias_calc_fn, mask_calc_fn=mask_calc_fn, - weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data)) + weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data, use_checkpoint=use_checkpoint)) _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) rl = [res[i] for i in range(res.shape[0])] From c794f0badef5c0469777572134fca31ab3221263 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 16:29:59 +0000 Subject: [PATCH 04/35] MPS fixes; now working --- src/diffusers/models/sub_quadratic_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 9560b5b4719a..2dc34f357317 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -9,6 +9,7 @@ # sparse broadcasting for bias, mask, weights # flattened conditions for clarity # Hyungon Ryu (device arg fix) +# Alex Birch (MPS support) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -51,11 +52,13 @@ def summarize_chunk(key_idx, query, key, value, mask, bias): attn_weights = torch.where(mask, attn_weights, big_neg) if weights_calc_fn is not None: attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data) + attn_weights = attn_weights.contiguous() if attn_weights.device.type == 'mps' else attn_weights max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights) max_score = torch.einsum('...qhk->...qh', max_score) + exp_values = exp_values.contiguous() if exp_values.device.type == 'mps' else exp_values return exp_values, exp_weights.sum(dim=-1), max_score summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk From 04a5cbe2865e83bfc93d5bdff2f8a756d1c98239 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 17:25:03 +0000 Subject: [PATCH 05/35] eliminate all einsums. assume 3D tensor [batch * num_heads, tokens, channels_per_head] in order to make use of batched matmuls. fuse multiply into matmul. breaks bias, mask in exchange for massive speedup. 
--- src/diffusers/models/cross_attention.py | 8 ++--- .../models/sub_quadratic_attention.py | 36 ++++++++++++------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index cfef098e3bae..0dd45dd821e2 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -282,9 +282,9 @@ def __call__( key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) + query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) + key = key.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) + value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) dtype = query.dtype # TODO: do we still need this given how we delay the division? @@ -303,7 +303,7 @@ def __call__( ) hidden_states = hidden_states.to(dtype) - hidden_states = hidden_states.flatten(2) + hidden_states = hidden_states.unflatten(0, (-1, attn.heads)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = attn.to_out hidden_states = out_proj(hidden_states) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 2dc34f357317..e5d0e94d9dfe 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -9,7 +9,10 @@ # sparse broadcasting for bias, mask, weights # flattened conditions for clarity # Hyungon Ryu (device arg fix) -# Alex Birch (MPS support) +# Alex Birch +# option to forego checkpointing (not needed during inference) +# MPS support +# optimizations (batched matmul, fused multiply) (at the expense of support for mask + bias) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -34,31 +37,40 @@ def _query_chunk_attention(query_idx, query, key, value, v_features = value.shape[-1] num_q = query.shape[-3] key_chunk_size = min(key_chunk_size or int(math.sqrt(num_kv)), num_kv) - query = query / math.sqrt(k_features) + scale = k_features ** -0.5 def summarize_chunk(key_idx, query, key, value, mask, bias): - attn_weights = torch.einsum('...qhd,...khd->...qhk', query, key) + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) if bias_calc_fn is not None: + raise "bias_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented bias = bias_calc_fn(query_idx, key_idx, bias, attn_weights, calc_fn_data) if bias is not None: + raise "bias no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented bias = torch.einsum('...hqk->...qhk', bias) attn_weights = attn_weights + bias if mask_calc_fn is not None: + raise "mask_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data) if mask is not None: + raise "mask no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented big_neg = torch.finfo(attn_weights.dtype).min big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32) mask = torch.einsum('...hqk->...qhk', mask) attn_weights = torch.where(mask, attn_weights, big_neg) if 
weights_calc_fn is not None: + raise "weights_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data) - attn_weights = attn_weights.contiguous() if attn_weights.device.type == 'mps' else attn_weights max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights) - max_score = torch.einsum('...qhk->...qh', max_score) - exp_values = exp_values.contiguous() if exp_values.device.type == 'mps' else exp_values + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) return exp_values, exp_weights.sum(dim=-1), max_score summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk @@ -115,14 +127,13 @@ def efficient_dot_product_attention(query, key, value, """Computes efficient dot-product attention given query, key, and value. This is efficient version of attention presented in https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. - Note: query, key, value needn't have any batch dimensions. Args: query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. + `[batch * num_heads, tokens, channels_per_head]`. key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. + `[batch * num_heads, tokens, channels_per_head]`. value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + `[batch * num_heads, tokens, channels_per_head]`. bias: bias for the attention weights. This should be broadcastable to the shape `[batch..., num_heads, q_length, kv_length]`. This can be used for incorporating padding masks, proximity bias, etc. @@ -150,8 +161,9 @@ def efficient_dot_product_attention(query, key, value, calc_fn_data: optional pure data to pass to each per-chunk call of bias_calc_fn, mask_calc_fn, and weights_calc_fn. weights_calc_data: pure_data to pass with each call to weights_calc_fn + use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) Returns: - Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. """ num_q, num_heads, q_features = query.shape[-3:] num_kv = key.shape[-3] From b44fa12bbc344ee40d3227854e6ca757ffed6748 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 17:40:01 +0000 Subject: [PATCH 06/35] remove the bits that I broke in the pursuit of speed (mask, bias, weights_calc_fn, calc_fn_data) and unused vars --- src/diffusers/models/cross_attention.py | 2 +- .../models/sub_quadratic_attention.py | 106 +----------------- 2 files changed, 6 insertions(+), 102 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 0dd45dd821e2..73c3d135c872 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -276,7 +276,7 @@ def __call__( ): encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states - assert attention_mask is None, "attention-mask not currently tested for SubQuadraticCrossAttnProcessor." + assert attention_mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index e5d0e94d9dfe..7b06e449a492 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -25,21 +25,15 @@ from typing import Optional -def _query_chunk_attention(query_idx, query, key, value, - mask, bias, +def _query_chunk_attention(query, key, value, key_chunk_size: Optional[int]=None, - mask_calc_fn=None, - bias_calc_fn=None, - weights_calc_fn=None, - calc_fn_data=None, use_checkpoint=True): num_kv, num_heads, k_features = key.shape[-3:] v_features = value.shape[-1] - num_q = query.shape[-3] key_chunk_size = min(key_chunk_size or int(math.sqrt(num_kv)), num_kv) scale = k_features ** -0.5 - def summarize_chunk(key_idx, query, key, value, mask, bias): + def summarize_chunk(query, key, value): attn_weights = torch.baddbmm( torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), query, @@ -47,25 +41,6 @@ def summarize_chunk(key_idx, query, key, value, mask, bias): alpha=scale, beta=0, ) - if bias_calc_fn is not None: - raise "bias_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented - bias = bias_calc_fn(query_idx, key_idx, bias, attn_weights, calc_fn_data) - if bias is not None: - raise "bias no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented - bias = torch.einsum('...hqk->...qhk', bias) - attn_weights = attn_weights + bias - if mask_calc_fn is not None: - raise "mask_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented - mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data) - if mask is not None: - raise "mask no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented - big_neg = torch.finfo(attn_weights.dtype).min - big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32) - mask = torch.einsum('...hqk->...qhk', mask) - attn_weights = torch.where(mask, attn_weights, big_neg) - if weights_calc_fn is not None: - raise "weights_calc_fn no longer supported" # lost support as a result of migrating to 3D tensors; needs to be reimplemented - attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) @@ -80,27 +55,7 @@ def chunk_scanner(chunk_idx): value_chunk = dynamic_slice(value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), tuple(value.shape[:-3]) + (key_chunk_size, num_heads, v_features)) - if bias is None: - bias_chunk = None - elif bias.shape[-1] == 1: - bias_chunk = bias - elif bias.shape[-1] == num_kv: - bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, 0, chunk_idx), - tuple(bias.shape[:-3]) + (bias.shape[-3], bias.shape[-2], key_chunk_size)) - else: - raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') - - if mask is None: - mask_chunk = None - elif mask.shape[-1] == 1: - mask_chunk = mask - elif mask.shape[-1] == num_kv: - mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, 0, chunk_idx), - tuple(mask.shape[:-3]) + (mask.shape[-3], mask.shape[-2], key_chunk_size)) - else: - raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with 
key.shape[-3] == {num_kv}') - - return summarizer(chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) + return summarizer(query, key_chunk, value_chunk) chunk_values, chunk_weights, chunk_max = map_pt( chunk_scanner, xs=torch.arange(0, num_kv, key_chunk_size)) @@ -116,13 +71,8 @@ def chunk_scanner(chunk_idx): def efficient_dot_product_attention(query, key, value, - mask=None, bias=None, query_chunk_size=1024, key_chunk_size: Optional[int] = None, - bias_calc_fn=None, - mask_calc_fn=None, - weights_calc_fn=None, - calc_fn_data=None, use_checkpoint=True): """Computes efficient dot-product attention given query, key, and value. This is efficient version of attention presented in @@ -134,67 +84,21 @@ def efficient_dot_product_attention(query, key, value, `[batch * num_heads, tokens, channels_per_head]`. value: values to be used in attention with shape of `[batch * num_heads, tokens, channels_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating padding masks, proximity bias, etc. - mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. query_chunk_size: int: query chunks size key_chunk_size: int: key chunks size - bias_calc_fn: a bias calculation callback for each chunk, of form - `(q_offset, k_offset, bias_chunk, attn_weights, calc_fn_data) -> bias`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. - mask_calc_fn: a mask calculation callback for each chunk, of form - `(q_offset, k_offset, mask_chunk, attn_weights, calc_fn_data) -> mask`. - This can be used for incorporating causal or other large masks. - Attention weights are masked out if their corresponding mask value - is `False`. - weights_calc_fn: a general attn_weights callback for each chunk, of form - `(q_offset, k_offset, attn_weights, calc_fn_data) -> attn_weights`. - attn_weights has shape of - `[batch..., q_chunk_size, num_heads, k_chunk_size]`. - This can be used to implement complex weights processing in a memory - efficient way. - calc_fn_data: optional pure data to pass to each per-chunk call of - bias_calc_fn, mask_calc_fn, and weights_calc_fn. - weights_calc_data: pure_data to pass with each call to weights_calc_fn use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. 
""" num_q, num_heads, q_features = query.shape[-3:] - num_kv = key.shape[-3] def chunk_scanner(chunk_idx, _): query_chunk = dynamic_slice(query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) - if mask is None: - mask_chunk = None - elif mask.shape[-2] == 1: - mask_chunk = mask - elif mask.shape[-2] == num_q: - mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, chunk_idx, 0), - tuple(mask.shape[:-3]) + (mask.shape[-3], min(query_chunk_size, num_q), mask.shape[-1])) - else: - raise TypeError(f'mask.shape[-2] == {mask.shape[-2]} must broadcast with query.shape[-3] == {num_q}') - - if bias is None: - bias_chunk = None - elif bias.shape[-2] == 1: - bias_chunk = bias - elif bias.shape[-2] == num_q: - bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, chunk_idx, 0), - tuple(bias.shape[:-3]) + (bias.shape[-3], min(query_chunk_size, num_q), bias.shape[-1])) - else: - raise TypeError(f'bias.shape[-2] == {bias.shape[-2]} must broadcast with query.shape[-3] == {num_q}') return (chunk_idx + query_chunk_size, - _query_chunk_attention(chunk_idx, query_chunk, key, value, mask_chunk, bias_chunk, key_chunk_size=key_chunk_size, - bias_calc_fn=bias_calc_fn, mask_calc_fn=mask_calc_fn, - weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data, use_checkpoint=use_checkpoint)) + _query_chunk_attention(query_chunk, key, value, key_chunk_size=key_chunk_size, + use_checkpoint=use_checkpoint)) _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) rl = [res[i] for i in range(res.shape[0])] From 8694703562125ca1a4d9091ef761d809ae94c665 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 17:40:49 +0000 Subject: [PATCH 07/35] clarify comment; verified that upcast_attention is indeed still helpful for SD 2.1. but remove value float32, having established that it works without. --- src/diffusers/models/cross_attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 73c3d135c872..cbb20141aa81 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -287,11 +287,10 @@ def __call__( value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) dtype = query.dtype - # TODO: do we still need this given how we delay the division? + # TODO: do we still need to do *everything* in float32, given how we delay the division? if attn.upcast_attention: query = query.float() key = key.float() - value = value.float() hidden_states = efficient_dot_product_attention( query, From 5bfe96d31fc607546bb42bb0240b6fe6664a8ae9 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 17:41:00 +0000 Subject: [PATCH 08/35] add TODO about softmax --- src/diffusers/models/cross_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index cbb20141aa81..ecf1b624547c 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -288,6 +288,7 @@ def __call__( dtype = query.dtype # TODO: do we still need to do *everything* in float32, given how we delay the division? + # TODO: do we need to support upcast_softmax too? 
SD 2.1 seems to work without it if attn.upcast_attention: query = query.float() key = key.float() From da8901b2b8184d52ea06c25456e86d13e3a65cc1 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 18:54:54 +0000 Subject: [PATCH 09/35] typings --- .../models/sub_quadratic_attention.py | 58 ++++++++++++++----- .../utils/attention_slicing_utils.py | 45 +++++++++++--- 2 files changed, 80 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 7b06e449a492..43bfd3f972dc 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -13,27 +13,46 @@ # option to forego checkpointing (not needed during inference) # MPS support # optimizations (batched matmul, fused multiply) (at the expense of support for mask + bias) +# typings # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 from functools import partial import torch +from torch import Tensor from torch.utils.checkpoint import checkpoint -from ..utils.attention_slicing_utils import dynamic_slice, map_pt, scan +from ..utils.attention_slicing_utils import dynamic_slice, map_pt, scan, AttnChunk import math -from typing import Optional +from typing import Optional, NamedTuple, Protocol, List -def _query_chunk_attention(query, key, value, - key_chunk_size: Optional[int]=None, - use_checkpoint=True): +class SummarizeChunk(Protocol): + def __call__( + self, + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + key_chunk_size: Optional[int] = None, + use_checkpoint = True, +): num_kv, num_heads, k_features = key.shape[-3:] v_features = value.shape[-1] key_chunk_size = min(key_chunk_size or int(math.sqrt(num_kv)), num_kv) scale = k_features ** -0.5 - def summarize_chunk(query, key, value): + def summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: attn_weights = torch.baddbmm( torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), query, @@ -46,10 +65,10 @@ def summarize_chunk(query, key, value): exp_weights = torch.exp(attn_weights - max_score) exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1) - return exp_values, exp_weights.sum(dim=-1), max_score - summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + summarizer: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk - def chunk_scanner(chunk_idx): + def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice(key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), tuple(key.shape[:-3]) + (key_chunk_size, num_heads, k_features)) value_chunk = dynamic_slice(value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), @@ -69,11 +88,18 @@ def chunk_scanner(chunk_idx): all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk -def efficient_dot_product_attention(query, key, value, - query_chunk_size=1024, - key_chunk_size: Optional[int] = None, - use_checkpoint=True): +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + key_chunk_size: Optional[int] = None, + use_checkpoint=True, +): """Computes efficient dot-product 
attention given query, key, and value. This is efficient version of attention presented in https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. @@ -92,14 +118,14 @@ def efficient_dot_product_attention(query, key, value, """ num_q, num_heads, q_features = query.shape[-3:] - def chunk_scanner(chunk_idx, _): + def chunk_scanner(chunk_idx: int, _) -> AttnChunk: query_chunk = dynamic_slice(query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) - return (chunk_idx + query_chunk_size, + return ScannedChunk(chunk_idx + query_chunk_size, _query_chunk_attention(query_chunk, key, value, key_chunk_size=key_chunk_size, use_checkpoint=use_checkpoint)) _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) - rl = [res[i] for i in range(res.shape[0])] + rl: List[Tensor] = [res[i] for i in range(res.shape[0])] return torch.cat(rl, dim=-3) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index 44a7af9b773d..cbac0388b37b 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -5,31 +5,62 @@ # credit: # Amin Rezaei (primary author) # Hyungon Ryu (device arg fix) +# Alex Birch (typings) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 import torch import numpy as np +from torch import Tensor +from typing import Protocol, NamedTuple, Iterable, Optional, List, Tuple -def dynamic_slice(x, starts, sizes): +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: # start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) - starts = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] + starts: List[int] = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] for i, (start, size) in enumerate(zip(starts, sizes)): x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device)) return x +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor -def map_pt(f, xs): +class ChunkScanner(Protocol): + def __call__(self, chunk_idx: int) -> AttnChunk: ... + +def map_pt(f: ChunkScanner, xs: List[int]) -> Tuple[Tensor, ...]: t = [f(x) for x in xs] return tuple(map(torch.stack, zip(*t))) -def scan(f, init, xs, length=None): +class ScanOutput(NamedTuple): + carry: int + y: Tensor + +class ScanCallback(Protocol): + def __call__( + self, + carry: int, + value: Tensor, + ) -> ScanOutput: ... 
+ + +def scan( + f: ScanCallback, + init: int, + xs: Optional[Iterable[Tensor]], + length: Optional[int] = None +): if xs is None: - xs = [None] * length - carry = init - ys = [] + xs: List[Tensor] = [None] * length + carry: int = init + ys: List[Tensor] = [] for x in xs: carry, y = f(carry, x) ys.append(y) From 0c4d82f4034014fbbfe31a5b2f7370e517c68a64 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 18:57:40 +0000 Subject: [PATCH 10/35] simplify protocols --- src/diffusers/utils/attention_slicing_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index cbac0388b37b..09b6e4bf1f1d 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -32,7 +32,8 @@ class AttnChunk(NamedTuple): max_score: Tensor class ChunkScanner(Protocol): - def __call__(self, chunk_idx: int) -> AttnChunk: ... + @staticmethod + def __call__(chunk_idx: int) -> AttnChunk: ... def map_pt(f: ChunkScanner, xs: List[int]) -> Tuple[Tensor, ...]: t = [f(x) for x in xs] @@ -44,8 +45,8 @@ class ScanOutput(NamedTuple): y: Tensor class ScanCallback(Protocol): + @staticmethod def __call__( - self, carry: int, value: Tensor, ) -> ScanOutput: ... From c5e8e31c835039f4c845f1cdc5f2d4f5d0725b97 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 18:57:44 +0000 Subject: [PATCH 11/35] remove unused --- src/diffusers/utils/attention_slicing_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index 09b6e4bf1f1d..3ebfcf3e7c1f 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -20,7 +20,6 @@ def dynamic_slice( starts: List[int], sizes: List[int], ) -> Tensor: - # start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) starts: List[int] = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] for i, (start, size) in enumerate(zip(starts, sizes)): x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device)) From b16edc9527a712e979a82f55087ef8488518eb56 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 19:23:57 +0000 Subject: [PATCH 12/35] simplify protocol --- src/diffusers/models/sub_quadratic_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 43bfd3f972dc..33f76fd92ba4 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -28,8 +28,8 @@ class SummarizeChunk(Protocol): + @staticmethod def __call__( - self, query: Tensor, key: Tensor, value: Tensor, From b7fc3a8dc77d6d6623ea2501a445702b3d9d225d Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 19:24:26 +0000 Subject: [PATCH 13/35] fix tensor shape destructuring --- .../models/sub_quadratic_attention.py | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 33f76fd92ba4..211d65967287 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -43,10 +43,10 @@ def _query_chunk_attention( key_chunk_size: Optional[int] = None, 
use_checkpoint = True, ): - num_kv, num_heads, k_features = key.shape[-3:] - v_features = value.shape[-1] - key_chunk_size = min(key_chunk_size or int(math.sqrt(num_kv)), num_kv) - scale = k_features ** -0.5 + batch_x_heads, k_tokens, k_channels_per_head = key.shape[-3:] + v_channels_per_head = value.shape[-1] + key_chunk_size = min(key_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + scale = k_channels_per_head ** -0.5 def summarize_chunk( query: Tensor, @@ -69,15 +69,21 @@ def summarize_chunk( summarizer: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk def chunk_scanner(chunk_idx: int) -> AttnChunk: - key_chunk = dynamic_slice(key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), - tuple(key.shape[:-3]) + (key_chunk_size, num_heads, k_features)) - value_chunk = dynamic_slice(value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), - tuple(value.shape[:-3]) + (key_chunk_size, num_heads, v_features)) + key_chunk = dynamic_slice( + key, + tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), + tuple(key.shape[:-3]) + (batch_x_heads, key_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), + tuple(value.shape[:-3]) + (batch_x_heads, key_chunk_size, v_channels_per_head) + ) return summarizer(query, key_chunk, value_chunk) chunk_values, chunk_weights, chunk_max = map_pt( - chunk_scanner, xs=torch.arange(0, num_kv, key_chunk_size)) + chunk_scanner, xs=torch.arange(0, k_tokens, key_chunk_size)) global_max, _ = torch.max(chunk_max, 0, keepdim=True) max_diffs = torch.exp(chunk_max - global_max) @@ -116,16 +122,26 @@ def efficient_dot_product_attention( Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. """ - num_q, num_heads, q_features = query.shape[-3:] + batch_x_heads, q_tokens, q_channels_per_head = query.shape[-3:] def chunk_scanner(chunk_idx: int, _) -> AttnChunk: - query_chunk = dynamic_slice(query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), - tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) + query_chunk = dynamic_slice( + query, + tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), + tuple(query.shape[:-3]) + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) - return ScannedChunk(chunk_idx + query_chunk_size, - _query_chunk_attention(query_chunk, key, value, key_chunk_size=key_chunk_size, - use_checkpoint=use_checkpoint)) + return ScannedChunk( + chunk_idx + query_chunk_size, + _query_chunk_attention( + query_chunk, + key, + value, + key_chunk_size=key_chunk_size, + use_checkpoint=use_checkpoint, + ) + ) - _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) + _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(q_tokens / query_chunk_size)) rl: List[Tensor] = [res[i] for i in range(res.shape[0])] return torch.cat(rl, dim=-3) From 8f003c242cfdc324cdb3484ae3f325c6c817abc6 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 20:07:26 +0000 Subject: [PATCH 14/35] simplify dynamic_slice --- src/diffusers/utils/attention_slicing_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index 3ebfcf3e7c1f..a3b441ca046d 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -10,7 +10,6 @@ # Self-attention Does Not Need O(n2) Memory": # 
https://arxiv.org/abs/2112.05682v2 import torch -import numpy as np from torch import Tensor from typing import Protocol, NamedTuple, Iterable, Optional, List, Tuple @@ -20,10 +19,8 @@ def dynamic_slice( starts: List[int], sizes: List[int], ) -> Tensor: - starts: List[int] = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] - for i, (start, size) in enumerate(zip(starts, sizes)): - x = torch.index_select(x, i, torch.tensor(range(start, start + size), device=x.device)) - return x + slicing = [slice(start, start + size + 1) for start, size in zip(starts, sizes)] + return x[slicing] class AttnChunk(NamedTuple): exp_values: Tensor From 1334670468358515e5d227754412b154af56b1cd Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 20:18:48 +0000 Subject: [PATCH 15/35] simplify chunk scanning --- .../models/sub_quadratic_attention.py | 23 ++++++++-------- .../utils/attention_slicing_utils.py | 27 ------------------- 2 files changed, 12 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 211d65967287..3d2f6faebd91 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -124,24 +124,25 @@ def efficient_dot_product_attention( """ batch_x_heads, q_tokens, q_channels_per_head = query.shape[-3:] - def chunk_scanner(chunk_idx: int, _) -> AttnChunk: + def chunk_scanner(chunk_idx: int) -> Tensor: query_chunk = dynamic_slice( query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), tuple(query.shape[:-3]) + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - return ScannedChunk( - chunk_idx + query_chunk_size, - _query_chunk_attention( - query_chunk, - key, - value, - key_chunk_size=key_chunk_size, - use_checkpoint=use_checkpoint, - ) + return _query_chunk_attention( + query_chunk, + key, + value, + key_chunk_size=key_chunk_size, + use_checkpoint=use_checkpoint, ) + + res = torch.stack([ + chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size)) + ]) - _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(q_tokens / query_chunk_size)) + # _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(q_tokens / query_chunk_size)) rl: List[Tensor] = [res[i] for i in range(res.shape[0])] return torch.cat(rl, dim=-3) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index a3b441ca046d..fd822272effd 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -35,30 +35,3 @@ def map_pt(f: ChunkScanner, xs: List[int]) -> Tuple[Tensor, ...]: t = [f(x) for x in xs] return tuple(map(torch.stack, zip(*t))) - -class ScanOutput(NamedTuple): - carry: int - y: Tensor - -class ScanCallback(Protocol): - @staticmethod - def __call__( - carry: int, - value: Tensor, - ) -> ScanOutput: ... 
- - -def scan( - f: ScanCallback, - init: int, - xs: Optional[Iterable[Tensor]], - length: Optional[int] = None -): - if xs is None: - xs: List[Tensor] = [None] * length - carry: int = init - ys: List[Tensor] = [] - for x in xs: - carry, y = f(carry, x) - ys.append(y) - return carry, torch.stack(ys) \ No newline at end of file From 0676c1315a762ce76751edb5b355e6630de56a0d Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 20:45:01 +0000 Subject: [PATCH 16/35] inline sole use of map_pt function --- .../models/sub_quadratic_attention.py | 15 ++++++++----- .../utils/attention_slicing_utils.py | 22 +++---------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 3d2f6faebd91..591e721ec588 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -22,10 +22,14 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint -from ..utils.attention_slicing_utils import dynamic_slice, map_pt, scan, AttnChunk +from ..utils.attention_slicing_utils import dynamic_slice import math from typing import Optional, NamedTuple, Protocol, List +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor class SummarizeChunk(Protocol): @staticmethod @@ -35,7 +39,6 @@ def __call__( value: Tensor, ) -> AttnChunk: ... - def _query_chunk_attention( query: Tensor, key: Tensor, @@ -82,8 +85,11 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: return summarizer(query, key_chunk, value_chunk) - chunk_values, chunk_weights, chunk_max = map_pt( - chunk_scanner, xs=torch.arange(0, k_tokens, key_chunk_size)) + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, key_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk global_max, _ = torch.max(chunk_max, 0, keepdim=True) max_diffs = torch.exp(chunk_max - global_max) @@ -143,6 +149,5 @@ def chunk_scanner(chunk_idx: int) -> Tensor: chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size)) ]) - # _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(q_tokens / query_chunk_size)) rl: List[Tensor] = [res[i] for i in range(res.shape[0])] return torch.cat(rl, dim=-3) diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py index fd822272effd..1ba4eecb61a0 100644 --- a/src/diffusers/utils/attention_slicing_utils.py +++ b/src/diffusers/utils/attention_slicing_utils.py @@ -5,14 +5,12 @@ # credit: # Amin Rezaei (primary author) # Hyungon Ryu (device arg fix) -# Alex Birch (typings) +# Alex Birch (typings, deleted everything) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 -import torch from torch import Tensor -from typing import Protocol, NamedTuple, Iterable, Optional, List, Tuple - +from typing import List def dynamic_slice( x: Tensor, @@ -20,18 +18,4 @@ def dynamic_slice( sizes: List[int], ) -> Tensor: slicing = [slice(start, start + size + 1) for start, size in zip(starts, sizes)] - return x[slicing] - -class AttnChunk(NamedTuple): - exp_values: Tensor - exp_weights_sum: Tensor - max_score: Tensor - -class ChunkScanner(Protocol): - @staticmethod - def __call__(chunk_idx: int) -> AttnChunk: ... 
- -def map_pt(f: ChunkScanner, xs: List[int]) -> Tuple[Tensor, ...]: - t = [f(x) for x in xs] - return tuple(map(torch.stack, zip(*t))) - + return x[slicing] \ No newline at end of file From 264dfb76edcc47b76f8b97ead6b84d740b233684 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 20:49:30 +0000 Subject: [PATCH 17/35] simplify --- src/diffusers/models/sub_quadratic_attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 591e721ec588..d5e219418f82 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -148,6 +148,4 @@ def chunk_scanner(chunk_idx: int) -> Tensor: res = torch.stack([ chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size)) ]) - - rl: List[Tensor] = [res[i] for i in range(res.shape[0])] - return torch.cat(rl, dim=-3) + return res.flatten(end_dim=-3) From 205f55b837a3e6ff26f3a9513a4a41779ea0cf89 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 20:52:32 +0000 Subject: [PATCH 18/35] no longer using original utilities from memory-efficient-attention repository --- .../models/sub_quadratic_attention.py | 2 +- .../utils/attention_slicing_utils.py | 21 ------------------- src/diffusers/utils/dynamic_slice.py | 10 +++++++++ 3 files changed, 11 insertions(+), 22 deletions(-) delete mode 100644 src/diffusers/utils/attention_slicing_utils.py create mode 100644 src/diffusers/utils/dynamic_slice.py diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index d5e219418f82..ee945b79d5f9 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -22,7 +22,7 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint -from ..utils.attention_slicing_utils import dynamic_slice +from ..utils.dynamic_slice import dynamic_slice import math from typing import Optional, NamedTuple, Protocol, List diff --git a/src/diffusers/utils/attention_slicing_utils.py b/src/diffusers/utils/attention_slicing_utils.py deleted file mode 100644 index 1ba4eecb61a0..000000000000 --- a/src/diffusers/utils/attention_slicing_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -# original source: -# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/utils.py -# license: -# unspecified -# credit: -# Amin Rezaei (primary author) -# Hyungon Ryu (device arg fix) -# Alex Birch (typings, deleted everything) -# implementation of: -# Self-attention Does Not Need O(n2) Memory": -# https://arxiv.org/abs/2112.05682v2 -from torch import Tensor -from typing import List - -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], -) -> Tensor: - slicing = [slice(start, start + size + 1) for start, size in zip(starts, sizes)] - return x[slicing] \ No newline at end of file diff --git a/src/diffusers/utils/dynamic_slice.py b/src/diffusers/utils/dynamic_slice.py new file mode 100644 index 000000000000..366327d0337c --- /dev/null +++ b/src/diffusers/utils/dynamic_slice.py @@ -0,0 +1,10 @@ +from torch import Tensor +from typing import List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size + 1) for start, size in zip(starts, sizes)] + return x[slicing] \ No newline at end of file From 
1880c0ed1694086c690f6f1d386b4b97634135ce Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:04:52 +0000 Subject: [PATCH 19/35] fix query slicing --- src/diffusers/models/sub_quadratic_attention.py | 8 ++++---- src/diffusers/utils/dynamic_slice.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index ee945b79d5f9..ad5e192936da 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -133,7 +133,7 @@ def efficient_dot_product_attention( def chunk_scanner(chunk_idx: int) -> Tensor: query_chunk = dynamic_slice( query, - tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), + tuple([0] * (query.ndim - 3)) + (0, chunk_idx, 0), tuple(query.shape[:-3]) + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) @@ -145,7 +145,7 @@ def chunk_scanner(chunk_idx: int) -> Tensor: use_checkpoint=use_checkpoint, ) - res = torch.stack([ + res = torch.cat([ chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size)) - ]) - return res.flatten(end_dim=-3) + ], dim=1) + return res diff --git a/src/diffusers/utils/dynamic_slice.py b/src/diffusers/utils/dynamic_slice.py index 366327d0337c..046678bb51f4 100644 --- a/src/diffusers/utils/dynamic_slice.py +++ b/src/diffusers/utils/dynamic_slice.py @@ -6,5 +6,5 @@ def dynamic_slice( starts: List[int], sizes: List[int], ) -> Tensor: - slicing = [slice(start, start + size + 1) for start, size in zip(starts, sizes)] + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] return x[slicing] \ No newline at end of file From 8603c3077a4da33067a70fd84ed7c562b2a38bc7 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:07:19 +0000 Subject: [PATCH 20/35] fix kv chunking --- src/diffusers/models/sub_quadratic_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index ad5e192936da..bfc114d58cd2 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -74,12 +74,12 @@ def summarize_chunk( def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice( key, - tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), + tuple([0] * (key.ndim - 3)) + (0, chunk_idx, 0), tuple(key.shape[:-3]) + (batch_x_heads, key_chunk_size, k_channels_per_head) ) value_chunk = dynamic_slice( value, - tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), + tuple([0] * (value.ndim - 3)) + (0, chunk_idx, 0), tuple(value.shape[:-3]) + (batch_x_heads, key_chunk_size, v_channels_per_head) ) From 96e0d8c8bbcfc0ccf74400aa0d9dc7c9f4c353bd Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:11:02 +0000 Subject: [PATCH 21/35] simplify dynamic slicing --- src/diffusers/models/sub_quadratic_attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index bfc114d58cd2..3481ec72e76d 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -74,13 +74,13 @@ def summarize_chunk( def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice( key, - tuple([0] * (key.ndim - 3)) + (0, chunk_idx, 0), - tuple(key.shape[:-3]) + (batch_x_heads, 
key_chunk_size, k_channels_per_head) + (0, chunk_idx, 0), + (batch_x_heads, key_chunk_size, k_channels_per_head) ) value_chunk = dynamic_slice( value, - tuple([0] * (value.ndim - 3)) + (0, chunk_idx, 0), - tuple(value.shape[:-3]) + (batch_x_heads, key_chunk_size, v_channels_per_head) + (0, chunk_idx, 0), + (batch_x_heads, key_chunk_size, v_channels_per_head) ) return summarizer(query, key_chunk, value_chunk) @@ -128,13 +128,13 @@ def efficient_dot_product_attention( Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. """ - batch_x_heads, q_tokens, q_channels_per_head = query.shape[-3:] + batch_x_heads, q_tokens, q_channels_per_head = query.shape def chunk_scanner(chunk_idx: int) -> Tensor: query_chunk = dynamic_slice( query, - tuple([0] * (query.ndim - 3)) + (0, chunk_idx, 0), - tuple(query.shape[:-3]) + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) return _query_chunk_attention( From 63ca66d12f2b39f80339096d0a28ea3bafb7436e Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:12:33 +0000 Subject: [PATCH 22/35] removed bias, mask, weights, calc_fn, and the conditions controlling them --- src/diffusers/models/sub_quadratic_attention.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 3481ec72e76d..0f393e65f7f8 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -4,10 +4,6 @@ # unspecified # credit: # Amin Rezaei (primary author) -# xloem -# calc_fn -# sparse broadcasting for bias, mask, weights -# flattened conditions for clarity # Hyungon Ryu (device arg fix) # Alex Birch # option to forego checkpointing (not needed during inference) From f4c0bf4de656f1acb1a1991e0185b8c9bb2c21c5 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:12:47 +0000 Subject: [PATCH 23/35] device arg fix no longer included --- src/diffusers/models/sub_quadratic_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 0f393e65f7f8..0e3cb5520e6b 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -4,7 +4,6 @@ # unspecified # credit: # Amin Rezaei (primary author) -# Hyungon Ryu (device arg fix) # Alex Birch # option to forego checkpointing (not needed during inference) # MPS support From 624123f693202b79b7c3422e61ac5b45f75085e7 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:17:00 +0000 Subject: [PATCH 24/35] simplify --- src/diffusers/models/sub_quadratic_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 0e3cb5520e6b..99c323d72852 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -17,9 +17,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint -from ..utils.dynamic_slice import dynamic_slice import math from typing import Optional, NamedTuple, Protocol, List +from ..utils.dynamic_slice import dynamic_slice class AttnChunk(NamedTuple): exp_values: Tensor @@ -41,8 +41,8 @@ def _query_chunk_attention( key_chunk_size: Optional[int] = None, 
use_checkpoint = True, ): - batch_x_heads, k_tokens, k_channels_per_head = key.shape[-3:] - v_channels_per_head = value.shape[-1] + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape key_chunk_size = min(key_chunk_size or int(math.sqrt(k_tokens)), k_tokens) scale = k_channels_per_head ** -0.5 From 5b92dab5806eba6ff5691b6b614ee9da98817bd7 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 26 Dec 2022 21:17:19 +0000 Subject: [PATCH 25/35] clarify attributions now that algorithm has been substantially rewritten --- src/diffusers/models/sub_quadratic_attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 99c323d72852..84b073d4a06a 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -3,12 +3,8 @@ # license: # unspecified # credit: -# Amin Rezaei (primary author) -# Alex Birch -# option to forego checkpointing (not needed during inference) -# MPS support -# optimizations (batched matmul, fused multiply) (at the expense of support for mask + bias) -# typings +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 From 60f0a5e3f29a2aee55c06bc616af933872664327 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 00:00:46 +0000 Subject: [PATCH 26/35] add chunk_threshold_bytes to let you specify your safe memory limit, to prefer fast-path whenever unchunked attention would fit into memory. add kv_chunk_size_min to control the kv_chunk_size=None behaviour, so that sqrt(key_tokens) does not pick too small of a chunk size --- src/diffusers/models/cross_attention.py | 59 +++++++++++++++---- .../models/sub_quadratic_attention.py | 22 ++++--- 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index ecf1b624547c..51a60be0e78b 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -151,13 +151,22 @@ def set_subquadratic_attention( self, query_chunk_size = 1024, kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + chunk_threshold_bytes: Optional[int] = None, ): r""" Args: query_chunk_size (`int`, *optional*, defaults to `1024`) - kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key tokens) is used. + kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key_tokens) is used. + kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size is None`. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). + chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined: only bother chunking if the self-attn matmul would allocate more bytes than this. whenever we can fit traditional attention into memory: we should prefer to do so, as the unchunked algorithm is faster. 
""" - processor = SubQuadraticCrossAttnProcessor(query_chunk_size, kv_chunk_size) + processor = SubQuadraticCrossAttnProcessor( + query_chunk_size=query_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min=kv_chunk_size_min, + chunk_threshold_bytes=chunk_threshold_bytes, + ) self.set_processor(processor) @@ -254,18 +263,26 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No class SubQuadraticCrossAttnProcessor: query_chunk_size: int kv_chunk_size: Optional[int] + kv_chunk_size_min: Optional[int] + chunk_threshold_bytes: Optional[int] def __init__( self, query_chunk_size = 1024, - kv_chunk_size: Optional[int] = None + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + chunk_threshold_bytes: Optional[int] = None, ): r""" Args: query_chunk_size (`int`, *optional*, defaults to `1024`) - kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key tokens) is used. + kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key_tokens) is used. + kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size is None`. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). + chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined: only bother chunking if the self-attn matmul would allocate more bytes than this. whenever we can fit traditional attention into memory: we should prefer to do so, as the unchunked algorithm is faster. """ self.query_chunk_size = query_chunk_size self.kv_chunk_size = kv_chunk_size + self.kv_chunk_size_min = kv_chunk_size_min + self.chunk_threshold_bytes = chunk_threshold_bytes def __call__( self, @@ -277,6 +294,9 @@ def __call__( encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states assert attention_mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + # I don't know what test case can be used to determine whether softmax is computed at sufficient bit-width, + # but sub-quadratic attention has a pretty bespoke softmax (defers computation of the denominator) so this needs some thought. 
+ assert not attn.upcast_softmax or torch.finfo(hidden_states.dtype).bits >= 32, "upcast_softmax was requested, but is not implemented" query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) @@ -293,14 +313,29 @@ def __call__( query = query.float() key = key.float() - hidden_states = efficient_dot_product_attention( - query, - key, - value, - query_chunk_size=self.query_chunk_size, - key_chunk_size=self.kv_chunk_size, - use_checkpoint=attn.training, - ) + bytes_per_token = torch.finfo(query.dtype).bits//8 + batch_x_heads, q_tokens, _ = query.shape + _, k_tokens, _ = key.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes: + hidden_states = efficient_dot_product_attention( + query, + key, + value, + query_chunk_size=self.query_chunk_size, + kv_chunk_size=self.kv_chunk_size, + kv_chunk_size_min=self.kv_chunk_size_min, + use_checkpoint=attn.training, + ) + else: + # the big matmul fits into our memory limit; compute via unchunked attention (it's faster) + attention_probs = attn.get_attention_scores( + query, + key, + ) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = hidden_states.to(dtype) hidden_states = hidden_states.unflatten(0, (-1, attn.heads)).transpose(1,2).flatten(start_dim=2) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 84b073d4a06a..eab376fd6784 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -34,12 +34,15 @@ def _query_chunk_attention( query: Tensor, key: Tensor, value: Tensor, - key_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, use_checkpoint = True, ): batch_x_heads, k_tokens, k_channels_per_head = key.shape _, _, v_channels_per_head = value.shape - key_chunk_size = min(key_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) scale = k_channels_per_head ** -0.5 def summarize_chunk( @@ -66,18 +69,18 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice( key, (0, chunk_idx, 0), - (batch_x_heads, key_chunk_size, k_channels_per_head) + (batch_x_heads, kv_chunk_size, k_channels_per_head) ) value_chunk = dynamic_slice( value, (0, chunk_idx, 0), - (batch_x_heads, key_chunk_size, v_channels_per_head) + (batch_x_heads, kv_chunk_size, v_channels_per_head) ) return summarizer(query, key_chunk, value_chunk) chunks: List[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, key_chunk_size) + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk @@ -100,7 +103,8 @@ def efficient_dot_product_attention( key: Tensor, value: Tensor, query_chunk_size=1024, - key_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, ): """Computes efficient dot-product attention given query, key, and value. @@ -114,7 +118,8 @@ def efficient_dot_product_attention( value: values to be used in attention with shape of `[batch * num_heads, tokens, channels_per_head]`. 
query_chunk_size: int: query chunks size - key_chunk_size: int: key chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. @@ -132,7 +137,8 @@ def chunk_scanner(chunk_idx: int) -> Tensor: query_chunk, key, value, - key_chunk_size=key_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min=kv_chunk_size_min, use_checkpoint=use_checkpoint, ) From 48db71132fd0f7d2f308432558259b00665c305d Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 00:32:05 +0000 Subject: [PATCH 27/35] fast path for when we're just attention-slicing (i.e. chunking query but not kv) --- .../models/sub_quadratic_attention.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index eab376fd6784..4ea4d6a19f2a 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -79,20 +79,35 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: return summarizer(query, key_chunk, value_chunk) + if k_tokens <= kv_chunk_size: + # fast-path for when there's only one chunk + # this is literally just attention slicing btw + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + chunks: List[AttnChunk] = [ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk - global_max, _ = torch.max(chunk_max, 0, keepdim=True) - max_diffs = torch.exp(chunk_max - global_max) - chunk_values *= torch.unsqueeze(max_diffs, -1) - chunk_weights *= max_diffs + global_max, _ = torch.max(chunk_max, 0, keepdim=True) # this is just c[2].unsqueeze(0) + max_diffs = torch.exp(chunk_max - global_max) # this is ones_like(c[2].unsqueeze(0)) + chunk_values *= torch.unsqueeze(max_diffs, -1) # this is a no-op I suppose + chunk_weights *= max_diffs # this is a no-op I suppose - all_values = chunk_values.sum(dim=0) - all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) - return all_values / all_weights + all_values = chunk_values.sum(dim=0) # this is just c[0] + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) # this is just c[1] + return all_values / all_weights # so just c[0] / c[1]; don't need c[2] class ScannedChunk(NamedTuple): chunk_idx: int From ef20fb9cbad60ed8fabe370a958590fd1d542790 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 01:28:40 +0000 Subject: [PATCH 28/35] default kv_chunk_size was meant to be sqrt() of global key size, not of chunk key size. improve separation of concerns. 
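illustration (not part of this commit; the helper name and the numbers are made up): with the default taken as sqrt() of the *global* key length, the kv chunk count stays O(sqrt(n)) no matter how the query is chunked. a standalone sketch of the sizing logic this commit settles on inside efficient_dot_product_attention:

    import math

    def resolve_kv_chunking(k_tokens: int, kv_chunk_size=None, kv_chunk_size_min=None):
        # mirrors the defaulting applied over the *global* key token count
        kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
        if kv_chunk_size_min is not None:
            kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
        n_kv_chunks = math.ceil(k_tokens / kv_chunk_size)
        return kv_chunk_size, n_kv_chunks

    print(resolve_kv_chunking(4096))                         # (64, 64): 64 keys per chunk, 64 chunks
    print(resolve_kv_chunking(4096, kv_chunk_size_min=512))  # (512, 8): the floor prevents tiny chunks
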
--- .../models/sub_quadratic_attention.py | 132 ++++++++++-------- 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 4ea4d6a19f2a..b43884a7636b 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -30,40 +30,43 @@ def __call__( value: Tensor, ) -> AttnChunk: ... +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... + +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, value: Tensor, - kv_chunk_size: Optional[int] = None, - kv_chunk_size_min: Optional[int] = None, - use_checkpoint = True, -): + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: batch_x_heads, k_tokens, k_channels_per_head = key.shape _, _, v_channels_per_head = value.shape - kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) - if kv_chunk_size_min is not None: - kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) - scale = k_channels_per_head ** -0.5 - - def summarize_chunk( - query: Tensor, - key: Tensor, - value: Tensor, - ) -> AttnChunk: - attn_weights = torch.baddbmm( - torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), - query, - key.transpose(1,2), - alpha=scale, - beta=0, - ) - max_score, _ = torch.max(attn_weights, -1, keepdim=True) - max_score = max_score.detach() - exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) - max_score = max_score.squeeze(-1) - return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) - summarizer: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice( @@ -76,23 +79,7 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head) ) - - return summarizer(query, key_chunk, value_chunk) - - if k_tokens <= kv_chunk_size: - # fast-path for when there's only one chunk - # this is literally just attention slicing btw - attn_scores = torch.baddbmm( - torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), - query, - key.transpose(1,2), - alpha=scale, - beta=0, - ) - attn_probs = attn_scores.softmax(dim=-1) - del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) - return hidden_states_slice + return summarize_chunk(query, key_chunk, value_chunk) chunks: List[AttnChunk] = [ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) @@ -109,6 +96,25 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) # this is just c[1] return all_values / all_weights # so just c[0] / c[1]; don't need c[2] +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, 
+ key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk @@ -140,24 +146,40 @@ def efficient_dot_product_attention( Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. """ batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) - def chunk_scanner(chunk_idx: int) -> Tensor: - query_chunk = dynamic_slice( + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - - return _query_chunk_attention( - query_chunk, - key, - value, + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + partial( + _query_chunk_attention, kv_chunk_size=kv_chunk_size, - kv_chunk_size_min=kv_chunk_size_min, - use_checkpoint=use_checkpoint, + summarize_chunk=summarize_chunk, ) + ) + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices res = torch.cat([ - chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size)) + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) return res From 69a8d2e6556fc43a4d1e7025279af6b5888d497d Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 01:38:07 +0000 Subject: [PATCH 29/35] remove debug notes --- src/diffusers/models/sub_quadratic_attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index b43884a7636b..64c3c566653b 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -87,14 +87,14 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk - global_max, _ = torch.max(chunk_max, 0, keepdim=True) # this is just c[2].unsqueeze(0) - max_diffs = torch.exp(chunk_max - global_max) # this is ones_like(c[2].unsqueeze(0)) - chunk_values *= torch.unsqueeze(max_diffs, -1) # this is a no-op I suppose - chunk_weights *= max_diffs # this is a no-op I suppose + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs - all_values = chunk_values.sum(dim=0) # this is just c[0] - all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) # this is just c[1] - return all_values / all_weights # so 
just c[0] / c[1]; don't need c[2] + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( From db2593495c47ed9a5e962b135c31e5d16bbb4a20 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 02:00:38 +0000 Subject: [PATCH 30/35] explain kv fast-path --- src/diffusers/models/sub_quadratic_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 64c3c566653b..ca39e7edd361 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -166,6 +166,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor: _get_attention_scores_no_kv_chunking, scale=scale ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) partial( _query_chunk_attention, kv_chunk_size=kv_chunk_size, From 7aa8bac840b2bed41c5ee00bd4029629d89c8ed5 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 02:00:48 +0000 Subject: [PATCH 31/35] add fast-path for "1 query chunk" --- src/diffusers/models/sub_quadratic_attention.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index ca39e7edd361..4f590702b7b0 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -173,6 +173,14 @@ def get_query_chunk(chunk_idx: int) -> Tensor: summarize_chunk=summarize_chunk, ) ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # and pass slices to be mutated, instead of torch.cat()ing the returned slices From 59002c33af66d561f9c844ed4ae047b1e1d1910f Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 02:16:04 +0000 Subject: [PATCH 32/35] move kv_chunk_size_min concern to callsite, since if caller knows final kv_chunk_size: they can notice when no chunking would happen at all, and use fast-path. note: there's a question of whether that concern belongs *inside* the algorithm. but it'd feel weird for chunked attention to have a no-chunking-at-all branch. 
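for illustration only (hypothetical helper, not part of the diff below): resolving kv_chunk_size at the callsite means the caller can tell up front whether any chunking will happen at all, roughly like this:

    import math

    def should_chunk(q_tokens: int, k_tokens: int,
                     query_chunk_size: int = 1024,
                     kv_chunk_size=None, kv_chunk_size_min=None) -> bool:
        # resolve the effective kv chunk size the same way the attention routine would
        kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
        if kv_chunk_size_min is not None:
            kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
        # chunking only does anything if at least one axis actually gets split
        return q_tokens > query_chunk_size or k_tokens > kv_chunk_size

    print(should_chunk(q_tokens=4096, k_tokens=77))                        # True: the query axis gets split
    print(should_chunk(q_tokens=512, k_tokens=77, kv_chunk_size_min=512))  # False: everything fits in one chunk
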
--- src/diffusers/models/cross_attention.py | 12 +++++++++--- src/diffusers/models/sub_quadratic_attention.py | 6 ------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 51a60be0e78b..52f240452177 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F from torch import nn, Tensor +import math from ..utils.import_utils import is_xformers_available @@ -318,14 +319,19 @@ def __call__( _, k_tokens, _ = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes: + kv_chunk_size = min(self.kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if self.kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, self.kv_chunk_size_min) + + uses_chunking = q_tokens > self.query_chunk_size or k_tokens > kv_chunk_size + + if uses_chunking and (self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes): hidden_states = efficient_dot_product_attention( query, key, value, query_chunk_size=self.query_chunk_size, - kv_chunk_size=self.kv_chunk_size, - kv_chunk_size_min=self.kv_chunk_size_min, + kv_chunk_size=kv_chunk_size, use_checkpoint=attn.training, ) else: diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 4f590702b7b0..a2786a3f7eff 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -125,7 +125,6 @@ def efficient_dot_product_attention( value: Tensor, query_chunk_size=1024, kv_chunk_size: Optional[int] = None, - kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, ): """Computes efficient dot-product attention given query, key, and value. @@ -140,7 +139,6 @@ def efficient_dot_product_attention( `[batch * num_heads, tokens, channels_per_head]`. query_chunk_size: int: query chunks size kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) - kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. 
@@ -149,10 +147,6 @@ def efficient_dot_product_attention( _, k_tokens, _ = key.shape scale = q_channels_per_head ** -0.5 - kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) - if kv_chunk_size_min is not None: - kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) - def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, From a3152d86e356b618e68a3836e33cd4d16df689e2 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 12:32:28 +0000 Subject: [PATCH 33/35] Revert "move kv_chunk_size_min concern to callsite (1c4f10748e31d18514ff1d5ed9fd9c67a278275b)" because equivalent fast-path for 1 query chunk, 1 kv chunk is already supported inside --- src/diffusers/models/cross_attention.py | 12 +++--------- src/diffusers/models/sub_quadratic_attention.py | 6 ++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 52f240452177..51a60be0e78b 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -17,7 +17,6 @@ import torch import torch.nn.functional as F from torch import nn, Tensor -import math from ..utils.import_utils import is_xformers_available @@ -319,19 +318,14 @@ def __call__( _, k_tokens, _ = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - kv_chunk_size = min(self.kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) - if self.kv_chunk_size_min is not None: - kv_chunk_size = max(kv_chunk_size, self.kv_chunk_size_min) - - uses_chunking = q_tokens > self.query_chunk_size or k_tokens > kv_chunk_size - - if uses_chunking and (self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes): + if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes: hidden_states = efficient_dot_product_attention( query, key, value, query_chunk_size=self.query_chunk_size, - kv_chunk_size=kv_chunk_size, + kv_chunk_size=self.kv_chunk_size, + kv_chunk_size_min=self.kv_chunk_size_min, use_checkpoint=attn.training, ) else: diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index a2786a3f7eff..4f590702b7b0 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -125,6 +125,7 @@ def efficient_dot_product_attention( value: Tensor, query_chunk_size=1024, kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, ): """Computes efficient dot-product attention given query, key, and value. @@ -139,6 +140,7 @@ def efficient_dot_product_attention( `[batch * num_heads, tokens, channels_per_head]`. query_chunk_size: int: query chunks size kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) Returns: Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. 
@@ -147,6 +149,10 @@ def efficient_dot_product_attention( _, k_tokens, _ = key.shape scale = q_channels_per_head ** -0.5 + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, From 0eafb95bcaff1bbf60ed6d79aa44a43803f643f4 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 28 Dec 2022 12:43:58 +0000 Subject: [PATCH 34/35] de-duplicate fast-path for "matmul < quota". we can just ask for everything in one chunk, to re-use an existing fast-path. --- src/diffusers/models/cross_attention.py | 35 +++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 51a60be0e78b..89e0c4d545ea 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -318,23 +318,24 @@ def __call__( _, k_tokens, _ = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes: - hidden_states = efficient_dot_product_attention( - query, - key, - value, - query_chunk_size=self.query_chunk_size, - kv_chunk_size=self.kv_chunk_size, - kv_chunk_size_min=self.kv_chunk_size_min, - use_checkpoint=attn.training, - ) - else: - # the big matmul fits into our memory limit; compute via unchunked attention (it's faster) - attention_probs = attn.get_attention_scores( - query, - key, - ) - hidden_states = torch.bmm(attention_probs, value) + query_chunk_size = self.query_chunk_size + kv_chunk_size = self.kv_chunk_size + + if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + hidden_states = efficient_dot_product_attention( + query, + key, + value, + query_chunk_size=query_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min=self.kv_chunk_size_min, + use_checkpoint=attn.training, + ) hidden_states = hidden_states.to(dtype) From 9dc68226a37fcdd0b387a1a20bb69e0d4b0dbb02 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 30 Dec 2022 17:13:40 +0000 Subject: [PATCH 35/35] pre-transpose key, rather than transposing it then undoing the transpose during the matmul --- src/diffusers/models/cross_attention.py | 9 ++--- .../models/sub_quadratic_attention.py | 36 +++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 89e0c4d545ea..1e8453399dff 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -303,7 +303,8 @@ def __call__( value = attn.to_v(encoder_hidden_states) query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) - key = key.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) + key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1) + del key value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) dtype = query.dtype @@ -311,11 +312,11 @@ def __call__( # TODO: do we need to support upcast_softmax too? 
SD 2.1 seems to work without it if attn.upcast_attention: query = query.float() - key = key.float() + key_t = key_t.float() bytes_per_token = torch.finfo(query.dtype).bits//8 batch_x_heads, q_tokens, _ = query.shape - _, k_tokens, _ = key.shape + _, _, k_tokens = key_t.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens query_chunk_size = self.query_chunk_size @@ -329,7 +330,7 @@ def __call__( hidden_states = efficient_dot_product_attention( query, - key, + key_t, value, query_chunk_size=query_chunk_size, kv_chunk_size=kv_chunk_size, diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py index 4f590702b7b0..a2e8aea513f5 100644 --- a/src/diffusers/models/sub_quadratic_attention.py +++ b/src/diffusers/models/sub_quadratic_attention.py @@ -26,7 +26,7 @@ class SummarizeChunk(Protocol): @staticmethod def __call__( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, ) -> AttnChunk: ... @@ -34,20 +34,20 @@ class ComputeQueryChunkAttn(Protocol): @staticmethod def __call__( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, ) -> Tensor: ... def _summarize_chunk( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, scale: float, ) -> AttnChunk: attn_weights = torch.baddbmm( torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), query, - key.transpose(1,2), + key_t, alpha=scale, beta=0, ) @@ -60,19 +60,19 @@ def _summarize_chunk( def _query_chunk_attention( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, ) -> Tensor: - batch_x_heads, k_tokens, k_channels_per_head = key.shape + batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> AttnChunk: key_chunk = dynamic_slice( - key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + key_t, + (0, 0, chunk_idx), + (batch_x_heads, k_channels_per_head, kv_chunk_size) ) value_chunk = dynamic_slice( value, @@ -99,14 +99,14 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, scale: float, ) -> Tensor: attn_scores = torch.baddbmm( torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), query, - key.transpose(1,2), + key_t, alpha=scale, beta=0, ) @@ -121,21 +121,21 @@ class ScannedChunk(NamedTuple): def efficient_dot_product_attention( query: Tensor, - key: Tensor, + key_t: Tensor, value: Tensor, query_chunk_size=1024, kv_chunk_size: Optional[int] = None, kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, ): - """Computes efficient dot-product attention given query, key, and value. + """Computes efficient dot-product attention given query, transposed key, and value. This is efficient version of attention presented in https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. Args: query: queries for calculating attention with shape of `[batch * num_heads, tokens, channels_per_head]`. - key: keys for calculating attention with shape of - `[batch * num_heads, tokens, channels_per_head]`. + key_t: keys for calculating attention with shape of + `[batch * num_heads, channels_per_head, tokens]`. value: values to be used in attention with shape of `[batch * num_heads, tokens, channels_per_head]`. 
query_chunk_size: int: query chunks size @@ -146,7 +146,7 @@ def efficient_dot_product_attention( Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. """ batch_x_heads, q_tokens, q_channels_per_head = query.shape - _, k_tokens, _ = key.shape + _, _, k_tokens = key_t.shape scale = q_channels_per_head ** -0.5 kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) @@ -178,7 +178,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor: # fast-path for when there's just 1 query chunk return compute_query_chunk_attn( query=query, - key=key, + key_t=key_t, value=value, ) @@ -187,7 +187,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor: res = torch.cat([ compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), - key=key, + key_t=key_t, value=value, ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1)
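
usage sketch (not part of the patch series; shapes are illustrative and assume the module created above is importable as diffusers.models.sub_quadratic_attention): after this final commit the key is passed pre-transposed, so a direct call looks roughly like this:

    import torch
    from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention

    batch_x_heads, q_tokens, kv_tokens, channels = 16, 4096, 4096, 40

    query = torch.randn(batch_x_heads, q_tokens, channels)
    key_t = torch.randn(batch_x_heads, channels, kv_tokens)  # note: transposed layout
    value = torch.randn(batch_x_heads, kv_tokens, channels)

    out = efficient_dot_product_attention(
        query, key_t, value,
        query_chunk_size=1024,
        kv_chunk_size=None,    # defaults to sqrt(kv_tokens), i.e. 64 here
        kv_chunk_size_min=None,
        use_checkpoint=False,  # inference; True is recommended for training
    )
    print(out.shape)  # torch.Size([16, 4096, 40])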