Added medium and medium.en models for TensorRT-LLM backend #31

Open: wants to merge 1 commit into main

@colinator commented Feb 21, 2024

This seems to work for the "medium" and "medium.en" models now, with the TensorRT-LLM backend.
Fixes #30

@shashikg (Owner)

Hi @colinator, can you run some WER checks on the medium and medium.en models for the TensorRT-LLM backend? According to the TensorRT-LLM repo, they only support the large model.

You can use these to run the tests:
Prepare the env using this script: https://github.com/shashikg/WhisperS2T/blob/main/prepare_benchmark_env.sh
Then use this: https://github.com/shashikg/WhisperS2T/blob/main/scripts/benchmark_whisper_s2t.py

@colinator (Author)

Ahoy there. I couldn't find where TensorRT-LLM says it is only compatible with large-v2. Maybe in the 'builder'? But that's just the builder. Your code seemed to work. The benchmark doesn't calculate WER, right? The transcriptions seem plausible, though:

data/KINCAID46/audio/9.mp3 For over a decade, on streets, and in bars, and in living rooms, from Rio to Reykjavik, and everywhere in between, the debate has raged. Lionel Messi, or Christiano Ronaldo. Who is the greatest player in the world? The truth is, we've never seen a rivalry like this. Not in football, not in any of our major sports. But that debate might soon be over, and it might end without Messi or Ronaldo getting the trophy they so glaringly lack, the World Cup. To understand why it matters that two individuals haven't won a team trophy that's only up for grabs every 4 years, you have to know just how dominate Messi and Ronaldo have been. The winner is, Christiano Ronaldo. From 2008 through 2017, the Ballon d'Or, basically the sports most valuable player award, has gone to either Ronaldo or Messi. Five for Ronaldo, five for Messi, none for anyone else. It's not just Ballon d'Or's for Messi and Ronaldo. Messi is the all time leading scorer in La Liga, in the Spanish super ...

Do you want me to attach the transcription CSV files?

@colinator (Author)

Do you have a script that calculates WER from the CSV outputs? I see your WER function, but I'm not totally clear on what pre-processing (lowercasing, etc.) you do when you calculate it...

@colinator (Author) commented Feb 22, 2024

Well, here are the outputs. I'm on an RTX 3080, so it's slower than your results, and the batch size is 16 because of memory limits.

results.zip

@shashikg (Owner)

Hey, yes: I normalize the text and then perform lowercasing as well. Here: https://github.com/shashikg/WhisperS2T/blob/main/tools/text_normalizer.py#L75

Then run this evaluate function on the normalized texts: https://github.com/shashikg/WhisperS2T/blob/main/tools/metrics.py#L68
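
For anyone comparing numbers outside the repo, here is a minimal, self-contained sketch of the normalize-then-score idea. It is not the repo's code: `simple_normalize` is a hypothetical stand-in that only lowercases, strips punctuation, and collapses whitespace (the linked `TextNormalizer` may do considerably more), and `corpus_wer` computes a corpus-level word error rate, WER = (S + D + I) / N, from a word-level Levenshtein alignment. The SER/DER/IER key names mirror the benchmark output, on the assumption that they mean substitution, deletion, and insertion rates.

```python
import re
from typing import Dict, List, Tuple


def simple_normalize(text: str) -> str:
    """Hypothetical stand-in for TextNormalizer: lowercase, drop punctuation
    (keeping apostrophes), and collapse runs of whitespace."""
    text = text.lower()
    text = re.sub(r"[^\w\s']", " ", text)
    return " ".join(text.split())


def align_counts(ref: List[str], hyp: List[str]) -> Tuple[int, int, int]:
    """Return (substitutions, deletions, insertions) for a minimum-edit alignment."""
    n, m = len(ref), len(hyp)
    # dp[i][j] = word-level edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = i
    for j in range(1, m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j - 1] + cost, dp[i - 1][j] + 1, dp[i][j - 1] + 1)
    # Walk back from the bottom-right cell and attribute each edit to a type.
    subs = dels = ins = 0
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]):
            subs += int(ref[i - 1] != hyp[j - 1])
            i, j = i - 1, j - 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            dels += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return subs, dels, ins


def corpus_wer(references: List[str], hypotheses: List[str]) -> Dict[str, float]:
    """Corpus-level WER in percent, split into (assumed) SER/DER/IER parts."""
    total_s = total_d = total_i = total_words = 0
    for ref_text, hyp_text in zip(references, hypotheses):
        ref = simple_normalize(ref_text).split()
        hyp = simple_normalize(hyp_text).split()
        s, d, i = align_counts(ref, hyp)
        total_s, total_d, total_i = total_s + s, total_d + d, total_i + i
        total_words += len(ref)
    scale = 100.0 / max(total_words, 1)
    return {
        "WER": (total_s + total_d + total_i) * scale,
        "SER": total_s * scale,
        "DER": total_d * scale,
        "IER": total_i * scale,
    }
```

For example, `corpus_wer(["Hello, World!", "the cat sat"], ["hello world", "the cat sat down"])` returns `{'WER': 20.0, 'SER': 0.0, 'DER': 0.0, 'IER': 20.0}`: normalization makes the first pair identical, and the second has one inserted word, over five reference words in total. Summing errors across the corpus before dividing (rather than averaging per-file WERs) keeps long files from being under-weighted.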

BTW, I quickly checked the output txt files, and the output looks good to me.

@shashikg added the enhancement (New feature or request) label Feb 22, 2024
@shashikg changed the title (previously "added medium and medium.en models, hopefully") Feb 22, 2024
@shashikg (Owner) commented Mar 1, 2024

Hi @colinator, any update?

@colinator (Author) commented Mar 4, 2024

I got these results for medium and medium.en. The card is an RTX 3080, if that matters...

Why is medium.en so much worse?

results/WhisperS2T-TensorRT-LLM-bs_16_medium
                Dataset        Time
0         KINCAID46 WAV   66.376619
1         KINCAID46 MP3   66.754254
2  MultiLingualLongform  158.262608
KINCAID46_WAV.tsv {'WER': 9.11, 'IER': 1.56, 'DER': 4.49, 'SER': 3.05, '5-GramInsertions': 35}
KINCAID46_MP3.tsv {'WER': 9.52, 'IER': 1.58, 'DER': 4.86, 'SER': 3.08, '5-GramInsertions': 31}
MultiLingualLongform.tsv {'WER': 9.4, 'IER': 3.03, 'DER': 3.19, 'SER': 3.18, '5-GramInsertions': 105}

results/WhisperS2T-TensorRT-LLM-bs_16_medium.en
                Dataset        Time
0         KINCAID46 WAV   64.459512
1         KINCAID46 MP3   69.381940
2  MultiLingualLongform  212.189986
KINCAID46_WAV.tsv {'WER': 13.34, 'IER': 1.36, 'DER': 9.09, 'SER': 2.89, '5-GramInsertions': 20}
KINCAID46_MP3.tsv {'WER': 12.14, 'IER': 1.33, 'DER': 7.93, 'SER': 2.88, '5-GramInsertions': 20}
MultiLingualLongform.tsv {'WER': 57.83, 'IER': 6.35, 'DER': 7.9, 'SER': 43.58, '5-GramInsertions': 7140}
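
When reading the tables above, it may help that in every row the WER equals IER + DER + SER to within rounding, which suggests (an assumption, not confirmed in this thread) that these are the insertion, deletion, and substitution rates. The sketch below copies the reported numbers and checks that decomposition; `REPORTED`, `component_sum`, and `largest_component` are hypothetical names.

```python
from typing import Dict, Tuple

# Metrics copied from the bs_16 results above (all values in percent).
# Assumption: IER/DER/SER are insertion/deletion/substitution error rates.
REPORTED: Dict[Tuple[str, str], Dict[str, float]] = {
    ("medium", "KINCAID46_WAV"): {"WER": 9.11, "IER": 1.56, "DER": 4.49, "SER": 3.05},
    ("medium", "KINCAID46_MP3"): {"WER": 9.52, "IER": 1.58, "DER": 4.86, "SER": 3.08},
    ("medium", "MultiLingualLongform"): {"WER": 9.4, "IER": 3.03, "DER": 3.19, "SER": 3.18},
    ("medium.en", "KINCAID46_WAV"): {"WER": 13.34, "IER": 1.36, "DER": 9.09, "SER": 2.89},
    ("medium.en", "KINCAID46_MP3"): {"WER": 12.14, "IER": 1.33, "DER": 7.93, "SER": 2.88},
    ("medium.en", "MultiLingualLongform"): {"WER": 57.83, "IER": 6.35, "DER": 7.9, "SER": 43.58},
}


def component_sum(scores: Dict[str, float]) -> float:
    """Sum of the (assumed) insertion, deletion, and substitution rates."""
    return scores["IER"] + scores["DER"] + scores["SER"]


def largest_component(scores: Dict[str, float]) -> str:
    """Name of the error type contributing most to the WER."""
    return max(("IER", "DER", "SER"), key=lambda k: scores[k])


if __name__ == "__main__":
    for (model, dataset), scores in REPORTED.items():
        print(f"{model:10s} {dataset:22s} WER={scores['WER']:6.2f} "
              f"I+D+S={component_sum(scores):6.2f} largest={largest_component(scores)}")
```

On these numbers, substitutions account for most of the medium.en MultiLingualLongform WER (SER 43.58 out of 57.83), while on KINCAID46 the largest component for both models is deletions.
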

@colinator (Author) commented Mar 4, 2024

Oh, here is the script that prints these results; it might be useful for some bigger pipeline. I'll just paste it here, since I'm not sure yet whether I should add it to this PR...

# Run like this, from the repo root:
# python -m scripts.print_wer_results --results_path results

import os
import argparse
from typing import List, Optional

import pandas as pd

from tools.metrics import evaluate
from tools.text_normalizer import TextNormalizer

def parse_arguments():
    parser = argparse.ArgumentParser(description="Print inference times and WER scores for each benchmark results directory.")
    parser.add_argument('--results_path', default="results", type=str)
    return parser.parse_args()

def results_from_tsv(path_to_tsv: str, normalize: Optional[TextNormalizer]):
    # keep_default_na=False reads empty transcriptions as "" instead of NaN,
    # so the normalizer always receives strings.
    df = pd.read_csv(path_to_tsv, sep="\t", keep_default_na=False)
    references = df['raw_text'].astype(str).to_list()
    hypotheses = df['pred_text'].astype(str).to_list()
    if normalize is not None:
        # Normalize (and lowercase) both sides before scoring.
        references = [normalize(t) for t in references]
        hypotheses = [normalize(t) for t in hypotheses]
    return evaluate(references, hypotheses)

def print_results_in_dir(path_to_dir: str, filenames: List[str], normalize: Optional[TextNormalizer]):
    print()
    print(path_to_dir)
    # Per-dataset inference times written by the benchmark script.
    print(pd.read_csv(os.path.join(path_to_dir, "infer_time.tsv"), sep="\t"))
    for tsv in filenames:
        tsv_path = os.path.join(path_to_dir, tsv)
        if not os.path.exists(tsv_path):
            print(tsv, "not found, skipped")
            continue
        print(tsv, results_from_tsv(tsv_path, normalize))

if __name__ == "__main__":
    args = parse_arguments()
    base_dir = args.results_path
    # One subdirectory per benchmarked configuration, in a stable order.
    results_directories = sorted(
        os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    )
    filenames = ["KINCAID46_WAV.tsv", "KINCAID46_MP3.tsv", "MultiLingualLongform.tsv"]
    normalizer = TextNormalizer()
    for results_dir in results_directories:
        print_results_in_dir(results_dir, filenames, normalize=normalizer)

@colinator (Author)

@shashikg ^^^

Labels: enhancement (New feature or request)
Participants: 2