fix CustomPromptPipeline #3275

Open · wants to merge 2 commits into main
Conversation

@idisuu commented Jun 1, 2023

@register_datapipeline
class CustomPromptPipeline(BasePipeline):
    """
    Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
    """

    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
        super().__init__()

Request to add **kwargs to the CustomPromptPipeline __init__ method so it stays compatible across trlx library versions.

The following error occurred while running RLHF with trainer_rl.py:

[screenshot: error traceback]

Analyzing the cause, I found that a parameter called add_special_tokens was added to the get_pipeline call in trlx.py on the main branch of the trlx library:

[screenshot: trlx.py source showing the add_special_tokens argument]
https://github.com/CarperAI/trlx/blob/355c9741f2e606de796f5c6f9b682f7dd00f97c5/trlx/trlx.py#L122-L125C6
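For reference, the call site on trlx main looks roughly like this (a paraphrase of the linked lines, not a verbatim quote; see the URL above for the exact code):

# Paraphrase of the linked trlx main-branch code. get_pipeline() looks up
# the registered pipeline class, which is then instantiated with the new
# keyword argument that OA's CustomPromptPipeline does not accept:
pipeline = get_pipeline(config.train.pipeline)(
    prompts,
    max_prompt_length,
    tokenizer,
    add_special_tokens=config.model.model_arch_type == "seq2seq",
)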

However, the parameter does not yet exist in the CustomPromptPipeline class on the main branch of OA.

The add_special_tokens parameter exists in the get_pipeline path of trlx.py on trlx's main branch, but it does not exist in the v0.6.0 release.

So I think we need to add **kwargs to prevent errors across different trlx versions.

The add_special_tokens parameter is True only when the model architecture is seq2seq, and previous OA versions have always effectively used False, so absorbing it via **kwargs seems okay for now (see the sketch below).
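A minimal sketch of the proposed change; the actual commits in this PR are authoritative, this only illustrates the **kwargs idea:

from typing import List

from transformers import PreTrainedTokenizer
from trlx.pipeline import BasePipeline, register_datapipeline


@register_datapipeline
class CustomPromptPipeline(BasePipeline):
    """
    Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
    """

    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer, **kwargs):
        # **kwargs absorbs arguments introduced by newer trlx versions
        # (e.g. add_special_tokens) that this pipeline does not use yet,
        # so the same class works across trlx versions.
        super().__init__()

With this signature, both `CustomPromptPipeline(prompts, max_len, tokenizer)` (old trlx) and `CustomPromptPipeline(prompts, max_len, tokenizer, add_special_tokens=False)` (trlx main) construct successfully.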

@sanagno (Collaborator) commented Jun 1, 2023

Thanks for the update! Are there any other changes required for the newer trlx version? In general, I am a bit sceptical about "blindly" updating to a newer version, as I have not checked what other changes have been made. If you can verify that no other breaking changes were introduced and that the new version has advantages, we can update the trlx requirement.

@idisuu (Author) commented Jun 5, 2023

I also discovered that the trlx version is not pinned:

"trlx @ git+https://github.com/CarperAI/trlx.git",

So the problem seems to be that the current OA code is compatible with the previous trlx version (v0.5.0), but an incompatibility appeared when trlx was updated to the latest version (main).

So I pinned the version in pyproject.toml.

The trlx v0.6.0 tag could not be installed via the pip install -e command, so I pinned v0.5.0 instead.
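Concretely, the pinned dependency line in pyproject.toml would look like this (assuming the upstream tag is literally named v0.5.0; adjust if the tag name differs):

"trlx @ git+https://github.com/CarperAI/trlx.git@v0.5.0",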

Note: the reason the trlx v0.6.0 tag fails to install appears to be an internal problem with the trlx library itself.

[screenshot: pip install failure for the trlx v0.6.0 tag]

@idisuu (Author) commented Jun 5, 2023

When using the trlx v0.5.0 tag, I found that an error occurred when loading the llama model during reinforcement learning.

Therefore, using v0.5.0 seems undesirable.

[screenshot: trlx v0.5.0 error when loading the llama model]

However, I think OA needs version management that keeps it compatible with trlx, but I am not sure how best to do that.
