
[Efficiency] The llama model with flash attention is slower than that without flash attention #26990

Closed · 2 of 4 tasks
KexinFeng opened this issue Oct 21, 2023 · 7 comments

KexinFeng commented Oct 21, 2023

System Info

The test was run with this fix applied: #26984

- `transformers` version: 4.34.0
- Platform: Linux-5.15.0-1045-aws-x86_64-with-glibc2.31
- Python version: 3.9.18
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The model loading:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_model_tokenizer(model_id, flash_attn=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_id_or_path = "huggyllama/llama-7b"  # hard-coded; the model_id argument is not used here
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path, device_map='auto' if device.type == 'cuda' else 'cpu',
        use_flash_attention_2=flash_attn)
    lm_block = HuggingfaceBlock(model)  # user-side wrapper around the HF model
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
                                              padding_side='left')
    tokenizer.pad_token = "[PAD]"

    return lm_block, tokenizer

Input_length = 760
batch_size = 13
Max_gen_token = [300, 100, 50, 20]

When `flash_attn=True`:

token_latency: [18.3 ms/token, 20.7 ms/token, 26.4 ms/token, 44.1 ms/token]

When `flash_attn=False`:

token_latency: [14.1 ms/token, 17.8 ms/token, 24.3 ms/token, 44.2 ms/token]
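
For context, here is a minimal sketch of how per-token latencies like the ones above could be measured (the actual benchmark harness around HuggingfaceBlock is not shown in this issue; the helper below and its names are illustrative only):

import time
import torch

def measure_token_latency(model, tokenizer, prompts, max_new_tokens):
    # Left-padded batch, matching the tokenizer configuration in get_model_tokenizer above.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    torch.cuda.synchronize()
    start = time.perf_counter()
    # min_new_tokens forces the full generation length so the per-token average is well defined.
    model.generate(**inputs, max_new_tokens=max_new_tokens,
                   min_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / max_new_tokens * 1000  # ms per generated token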

Expected behavior

Flash attention should accelerate the inference.

@younesbelkada
Contributor

Hi @KexinFeng
Thanks for the issue. Usually the speedup is quite considerable for large sequence lengths. Can you try your experiment with, for example, seq_len=2048? Also make sure to use a batch size that is divisible by 2.

@KexinFeng
Author

KexinFeng commented Oct 25, 2023

@younesbelkada Thanks for pointing out the sequence length. Indeed, at seq_len=3500, flash attention does gain a speedup. However, it is not significant compared to non-flash attention.

Input_length = 3500
batch_size = 4
Max_gen_token = [300, 100, 50, 20]

Corresponding to each max_gen_token:

flash_attn=True

token_latency = 33.9 ms/token, 39.7 ms/token, 49.3 ms/token, 78.8 ms/token 

flash_attn = False

token_latency = 28.8 ms/token, 39.9 ms/token, 57.3 ms/token, 110 ms/token 

I thought the expected behaviour was that flash attention should be strictly faster than non-flash attention. What factor contributes the extra overhead to flash attention compared to non-flash attention?

From the benchmark above, it seems that flash attention gets relatively slower as gen_token grows. This suggests that the overhead specific to flash attention is incurred at every decoding step, so the speedup gained at the prefill step is gradually overridden by this per-step overhead as decoding proceeds.
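
One way to make this reasoning explicit (assuming the reported figure is the total generation time divided by the number of generated tokens N, which is an assumption about the metric):

token_latency(N) = (T_prefill + N * t_decode) / N = t_decode + T_prefill / N

As N grows, the prefill term is amortized away and the reported latency approaches the per-step decode cost, so a fixed prefill saving from flash attention vanishes while any per-step overhead in t_decode stays visible at every N.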

@ArthurZucker
Collaborator

If you are passing the attention mask to the model, I think the pad and unpad operations add a non-negligible overhead.
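
For reference, the pad/unpad step being referred to is conceptually a gather/scatter pair executed on every forward pass. Below is a torch-only sketch of the idea (not the actual transformers implementation, which builds on helpers from flash_attn.bert_padding):

import torch
import torch.nn.functional as F

def unpad(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # offsets for varlen kernels
    flat = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]       # (total_real_tokens, hidden)
    return flat, indices, cu_seqlens, int(seqlens.max())

def repad(flat, indices, batch, seq_len):
    # Scatter the unpadded tokens back into their padded positions.
    out = torch.zeros(batch * seq_len, flat.shape[-1], dtype=flat.dtype, device=flat.device)
    out[indices] = flat
    return out.reshape(batch, seq_len, -1)

With batches that contain a lot of padding, the index bookkeeping and extra memory traffic are paid at every decoding step, which matches the per-step overhead discussed above.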

@KexinFeng
Author

@ArthurZucker Yes, indeed, I fed the attention mask into the model, with a lot of 0 entries (corresponding to the PAD token). Thanks for this insight. But is there any plan to remove this overhead? It seems to me that the flash attention algorithm in principle doesn't necessarily require the pad and unpad operations. Currently, the advantage of flash attention over the non-flash path is not clear.

@younesbelkada
Contributor

Hi @KexinFeng
As stated by @ArthurZucker, padding tokens in the sequence add a considerable overhead in the FA modules. The expected speedups and the best scenarios for using FA-2 are clearly stated in this section of the docs: https://huggingface.co/docs/transformers/perf_infer_gpu_one#expected-speedups

@KexinFeng
Author

KexinFeng commented Nov 1, 2023

@younesbelkada Thank you for pointing me to this document! Indeed, the issue I brought up here is documented there. What's more, the document also shows data on how the speedup depends on the maximum prompt length, which is very helpful.

However, regarding the solution proposed in the document,

To overcome this, one should use Flash Attention without padding tokens in the sequence for training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided here.

it doesn't seem applicable to the model inference and serving scenario, which is where this issue originates. Especially with dynamic batching at inference time, this packing of the dataset doesn't work. It seems to me that padding is unavoidable in inference scenarios. A possible way to avoid it would be to switch the flash attention kernel to something like var_len_single_query_attention (which already exists in the flash attention repo), where the input is flattened into a 1D tensor; see the sketch below.
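
For illustration, calling the variable-length kernel directly looks roughly like the following, using flash_attn_varlen_func from the flash-attn package (the exact signature may differ between versions, and the tensors here are placeholders):

import torch
from flash_attn import flash_attn_varlen_func

# Two sequences of lengths 5 and 3, packed into one flattened tensor with no padding.
seqlens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # [0, 5, 8]
total, nheads, headdim = int(seqlens.sum()), 32, 128

q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")

# cu_seqlens tells the kernel where each packed sequence starts and ends, so no pad tokens are needed.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)  # (total, nheads, headdim)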


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Dec 4, 2023