-
Notifications
You must be signed in to change notification settings - Fork 26
/
attention.py
288 lines (244 loc) · 8.91 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Common building blocks for Attention layer.
import math
from typing import Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
import ai_edge_torch.generative.layers.builder as builder
from ai_edge_torch.generative.layers.kv_cache import KVCache
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
def scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
head_size: int,
mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
):
"""Scaled dot product attention.
Args:
q (torch.Tensor): Query tensor, with shape [B, T, N, H].
k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
head_size (int): head dimension.
mask (torch.Tensor): the optional mask tensor.
Returns:
The output tensor of scaled_dot_product_attention.
"""
if scale is None:
scale = 1.0 / math.sqrt(head_size)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if q.size() != k.size():
# Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
y = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=0.0,
is_causal=mask is None,
scale=scale,
)
return y.transpose(1, 2)
def scaled_dot_product_attention_with_hlfb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    head_size: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
  """Scaled dot product attention with high-level function boundary enabled.

  Performs the same computation as a plain scaled dot product attention, but
  marks the inputs and the output through a `StableHLOCompositeBuilder` named
  "odml.scaled_dot_product_attention" that carries `scale` as an attribute.

  Args:
    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
    k (torch.Tensor): Key tensor, with shape [B, KV_LEN, N_KV, H], where N is
      a multiple of N_KV.
    v (torch.Tensor): Value tensor, with shape [B, KV_LEN, N_KV, H].
    head_size (int): head dimension. Only used to derive the default scale.
    mask (torch.Tensor): the optional mask tensor, forwarded as `attn_mask`.
      When it is None, causal masking is requested instead.
    scale (float): optional multiplier applied to the attention scores.
      Defaults to 1 / sqrt(head_size).

  Returns:
    The output tensor of scaled_dot_product_attention, with the sequence axis
    restored to position 1 (i.e. laid out like `q`).
  """
  if scale is None:
    scale = 1.0 / math.sqrt(head_size)

  # Named `composite_builder` so that it does not shadow the module-level
  # `builder` import used elsewhere in this file.
  composite_builder = StableHLOCompositeBuilder(
      name="odml.scaled_dot_product_attention", attr={"scale": scale}
  )
  q, k, v, mask = composite_builder.mark_inputs(q, k, v, mask)

  # Move heads in front of the sequence axis: [B, heads, seq, H].
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)

  # Handle the GQA/MQA case, where q.shape[1] % k.shape[1] == 0. Only the head
  # counts are compared: a mere sequence-length difference between q and k
  # (e.g. when k/v come from a KV cache) must not trigger a redundant copy.
  num_heads = q.shape[1]
  if k.shape[1] != num_heads:
    k = k.repeat_interleave(num_heads // k.shape[1], dim=1)
  if v.shape[1] != num_heads:
    v = v.repeat_interleave(num_heads // v.shape[1], dim=1)

  y = F.scaled_dot_product_attention(
      q,
      k,
      v,
      attn_mask=mask,
      dropout_p=0.0,
      is_causal=mask is None,
      scale=scale,
  )
  result = y.transpose(1, 2)
  result = composite_builder.mark_outputs(result)
  return result
class TransformerBlock(nn.Module):
  """Pre-norm transformer block: self attention followed by feed forward.

  Both sub-layers are wrapped in residual connections. Depending on
  `config.parallel_residual`, the feed-forward layer either consumes the same
  normalized input as attention (parallel) or the normalized attention
  residual (sequential).
  """

  def __init__(self, config: cfg.ModelConfig) -> None:
    """Initialize an instance of the TransformerBlock.

    Args:
      config (cfg.ModelConfig): the configuration object
        for this transformer block.
    """
    super().__init__()
    self.config = config
    embedding_dim = config.embedding_dim

    # Sub-modules are registered in the order: attention norm, attention,
    # feed-forward norm, feed-forward.
    self.pre_atten_norm = builder.build_norm(
        embedding_dim, config.pre_attention_norm_config
    )
    self.atten_func = CausalSelfAttention(
        embedding_dim,
        config.attn_config,
        config.kv_cache_max,
        config.enable_hlfb,
    )
    self.pre_ff_norm = builder.build_norm(
        embedding_dim, config.pre_ff_norm_config
    )
    self.ff = builder.build_ff(embedding_dim, config.ff_config)

  def forward(
      self,
      x: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
    """Forward function of the TransformerBlock.

    Args:
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.

    Returns:
      output activation from this transformer block.
    """
    # Both residual layouts start with the same normalized attention pass.
    normed = self.pre_atten_norm(x)
    attn_out = self.atten_func(normed, rope, mask, input_pos)

    if self.config.parallel_residual:
      # Parallel: feed-forward reads the attention-normalized input, and
      # `pre_ff_norm` is not applied.
      ff_out = self.ff(normed)
      return x + attn_out + ff_out

    # Sequential: feed-forward reads the re-normalized attention residual.
    residual = x + attn_out
    ff_in = self.pre_ff_norm(residual)
    return residual + self.ff(ff_in)
# CausalSelfAttention which can support MHA, MQA or GQA.
class CausalSelfAttention(nn.Module):
  """Self attention layer with fused QKV projection and optional KV cache."""

  def __init__(
      self,
      dim: int,
      config: cfg.AttentionConfig,
      kv_cache_max: int,
      enable_hlfb: bool,
  ) -> None:
    """Initialize an instance of CausalSelfAttention.

    Args:
      dim (int): causal attention's input/output dimension.
      config (cfg.AttentionConfig): attention specific configurations.
      kv_cache_max (int): determines the size of the KV Cache buffer, if
        enabled.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
    self.head_dim = dim // config.num_heads
    # One head-sized slice per query head, plus one key and one value slice
    # per query group.
    shape = (config.num_heads + 2 * config.num_query_groups) * self.head_dim
    # Key, query, value projections for all heads.
    self.qkv_projection = nn.Linear(dim, shape, bias=config.qkv_use_bias)
    self.output_projection = nn.Linear(
        dim, dim, bias=config.output_proj_use_bias
    )
    self.config = config
    self.kv_cache = None

    # Build a k/v cache sized by (batch_size, kv_cache_max, num_query_groups,
    # head_dim). Now only supports batch_size of 1.
    # TODO(haoliang): support batch_size greater than 1.
    if config.enable_kv_cache:
      self.kv_cache = KVCache(
          1,
          kv_cache_max,
          config.num_query_groups,
          self.head_dim,
          enable_hlfb,
      )

    if enable_hlfb:
      self.sdpa_func = scaled_dot_product_attention_with_hlfb
    else:
      self.sdpa_func = scaled_dot_product_attention

  def forward(
      self,
      x: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
    """Forward function of the CausalSelfAttention layer.

    Args:
      x (torch.Tensor): the input tensor, with shape [B, T, E]. B must be 1.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor, as a
        (cos, sin) pair.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor. It is
        passed to the KV cache update when the cache is enabled.

    Returns:
      output activation from this self attention layer.
    """
    # Batch size, sequence length, embedding dimensionality.
    B, T, E = x.size()
    assert B == 1, "Currently only batch_size = 1 is supported."

    qkv = self.qkv_projection(x)

    # Assemble into a number of query groups to support MHA, MQA and GQA.
    q_per_kv = self.config.num_heads // self.config.num_query_groups
    total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
    qkv = qkv.view(
        B, T, self.config.num_query_groups, total_qkv, self.head_dim
    )  # (B, T, num_query_groups, total_qkv, head_dim)

    # Split batched computation into three.
    q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
    q = q.reshape(B, T, -1, self.head_dim)
    k = k.reshape(B, T, -1, self.head_dim)
    v = v.reshape(B, T, -1, self.head_dim)

    # Compute rotary positional embedding for query and key. Only the first
    # n_elem channels of each head are rotated; the rest pass through.
    n_elem = int(self.config.rotary_percentage * self.head_dim)
    cos, sin = rope
    # The tiled cos/sin tables are shared by q and k, so build them once.
    cos = cos.repeat(1, 2)
    sin = sin.repeat(1, 2)
    q_roped = rotary_pos_emb.apply_rope(q[..., :n_elem], cos, sin)
    k_roped = rotary_pos_emb.apply_rope(k[..., :n_elem], cos, sin)
    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

    if self.kv_cache is not None:
      # TODO(haoliang): Handle when exceeding max sequence length.
      k, v = self.kv_cache.update_cache(input_pos, k, v)

    y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
    # Merge the heads back into the embedding dimension.
    y = y.reshape(B, T, E)

    # Compute the output projection.
    y = self.output_projection(y)
    return y