
Deprecated flag for flash attention 2 with Hugging Face backend #39

Open
BBC-Esq opened this issue Feb 25, 2024 · 1 comment
BBC-Esq commented Feb 25, 2024

Hello, just an FYI in case you didn't know: apparently Hugging Face changed the parameter used to request Flash Attention 2. Here's the message I got:

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.

And here's the script I am testing:

import whisper_s2t

model_kwargs = {
    'compute_type': 'float16',
    'asr_options': {
        "beam_size": 5,
        "without_timestamps": True,
        "return_scores": False,
        "return_no_speech_prob": False,
        "use_flash_attention": True,
        "use_better_transformer": False,
    },
    'model_identifier': "small",
    'backend': 'HuggingFace',
}

model = whisper_s2t.load_model(**model_kwargs)

files = ['test_audio_flac.flac']
lang_codes = ['en']
tasks = ['transcribe']
initial_prompts = [None]

out = model.transcribe_with_vad(files,
                                lang_codes=lang_codes,
                                tasks=tasks,
                                initial_prompts=initial_prompts,
                                batch_size=20)

transcription = " ".join([_['text'] for _ in out[0]]).strip()

with open('transcription.txt', 'w') as f:
    f.write(transcription)

BTW, I tried using the newer `attn_implementation="flash_attention_2"` with Bark and could not get it to work, yet with your program, which uses the old `use_flash_attention_2=True`, it works. I don't know whether the problem was my script or the different flags, but just be aware in case.
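
For reference, here is a minimal sketch of what the newer flag looks like when loading a model directly through transformers (the checkpoint name and surrounding details are just an illustration, not how this repo's backend actually loads models):

import torch
from transformers import AutoModelForSpeechSeq2Seq

# Initialize on CPU, then move to GPU afterwards, as the second warning suggests.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-small",                   # illustrative checkpoint, not from this repo
    torch_dtype=torch.float16,                # Flash Attention 2 needs fp16/bf16 weights
    attn_implementation="flash_attention_2",  # replaces the deprecated use_flash_attention_2=True
)
model = model.to("cuda")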

shashikg self-assigned this Mar 1, 2024
shashikg (Owner) commented Mar 1, 2024

Thanks, will update this in the next release.
