Cannot run inference using API for video chat with Stable_LM #70

Open
plischwe opened this issue Oct 10, 2023 · 1 comment

@plischwe

I have successfully launched the "Ask Anything with StableLM" demo and have its public Gradio URL, but I run into an error when I send the example API request that Gradio generates. The file I am running is:
from gradio_client import Client

client = Client("https://3774b146370bec32fe.gradio.live/")
result = client.predict(
    "https://github.com/gradio-app/gradio/raw/main/test/test_files/video_sample.mp4",  # str (filepath on your computer (or URL) of file) in 'Input Video' Video component
    "Howdy!",  # str in 'User Prompt (Optional, Enter with commas)' Textbox component
    fn_index=4,
)
print(result)
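As a sanity check on the request itself, gradio_client can print the hosted app's endpoint signatures; a minimal sketch, assuming the same (temporary) public URL is still live:

from gradio_client import Client

# view_api() prints every callable endpoint of the hosted app together with
# its fn_index and the parameters it expects, which confirms whether
# fn_index=4 really is the video-chat endpoint.
client = Client("https://3774b146370bec32fe.gradio.live/")
client.view_api(all_endpoints=True)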

The error I am receiving is:
Traceback (most recent call last):
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/queueing.py", line 388, in call_prediction
    output = await route_utils.call_process_api(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/route_utils.py", line 217, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/blocks.py", line 1554, in process_api
    result = await self.call_function(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/blocks.py", line 1192, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/utils.py", line 659, in wrapper
    response = f(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/app.py", line 77, in inference
    caption, tag_predict = model.generate(image,tag_input = input_tag_list,max_length = 50, return_tag_predict = True)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/tag2text.py", line 200, in generate
    outputs = self.text_decoder.generate(input_ids=input_ids,
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/transformers/generation/utils.py", line 1685, in generate
    return self.beam_search(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/transformers/generation/utils.py", line 3024, in beam_search
    outputs = self(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 962, in forward
    outputs = self.bert(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 858, in forward
    encoder_outputs = self.encoder(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 522, in forward
    layer_outputs = layer_module(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 438, in forward
    cross_attention_outputs = self.crossattention(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 333, in forward
    self_outputs = self.self(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 234, in forward
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
RuntimeError: The size of tensor a (24) must match the size of tensor b (9) at non-singleton dimension 0
File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 234, in forward

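For reference, the RuntimeError itself can be reproduced in isolation, since torch.matmul only broadcasts singleton batch dimensions; a minimal sketch, with the head count, sequence lengths, and head dim made up:

import torch

# Hypothetical shapes: the batch sizes 24 and 9 mirror the error message;
# the remaining dimensions are invented.
query_layer = torch.randn(24, 12, 10, 64)  # (batch, heads, q_len, head_dim)
key_layer = torch.randn(9, 12, 20, 64)     # mismatched batch dimension

# Raises: RuntimeError: The size of tensor a (24) must match the size of
# tensor b (9) at non-singleton dimension 0
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))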
Any input would be greatly appreciated - thanks.

@yinanhe
Member

yinanhe commented Jan 24, 2024

To solve this problem, you can refer to salesforce/BLIP#142 (comment).

The following is a possible solution:

Change attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) at line 234 of /home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py to:

# If the key/value tensors carry a larger batch (first) dimension than the
# query tensor, truncate them (and the matching attention mask) so the
# batched matmul below sees consistent batch sizes.
if key_layer.shape[0] > query_layer.shape[0]:
    key_layer = key_layer[:query_layer.shape[0], :, :, :]
    attention_mask = attention_mask[:query_layer.shape[0], :, :]
    value_layer = value_layer[:query_layer.shape[0], :, :, :]
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
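For illustration only, here is how the guard behaves on dummy tensors (shapes invented, with the key/value batch larger than the query batch so the truncation fires):

import torch

# Invented shapes: key/value batch (24) exceeds query batch (9).
query_layer = torch.randn(9, 12, 10, 64)
key_layer = torch.randn(24, 12, 20, 64)
value_layer = torch.randn(24, 12, 20, 64)
attention_mask = torch.zeros(24, 1, 20)

# The guard trims key/value/mask down to the query batch, after which the
# batched matmul succeeds.
if key_layer.shape[0] > query_layer.shape[0]:
    key_layer = key_layer[:query_layer.shape[0], :, :, :]
    attention_mask = attention_mask[:query_layer.shape[0], :, :]
    value_layer = value_layer[:query_layer.shape[0], :, :, :]

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
print(attention_scores.shape)  # torch.Size([9, 12, 10, 20])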