Skip to content

Commit

Permalink
Added: click support for additional_context (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
signebedi committed Apr 4, 2023
1 parent 5891bb7 commit ec4c746
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
9 changes: 5 additions & 4 deletions gptty/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,21 @@ async def chat_async_wrapper(config_path:str, verbose:bool):
# @click.option('--log', '-l', is_flag=True, callback=print_version,
# expose_value=False, is_eager=True, help="Show question log.")
@click.option('--config_path', '-c', default=os.path.join(os.getcwd(),'gptty.ini'), help="Path to config file.")
@click.option('--additional_context', '-a', default="", help="Pass more context to your questions.")
@click.option('--question', '-q', multiple=True, help='Repeatable list of questions.')
@click.option('--tag', '-t', default="", help='Tag to categorize your query. [optional]')
@click.option('--verbose', '-v', is_flag=True, help="Show debug data.")
@click.option('--json', '-j', is_flag=True, help="Return query as JSON object.")
@click.option('--quiet', is_flag=True, help="Don't write to stdout.")
def query(config_path:str, question:str, tag:str, verbose:bool, json:bool, quiet:bool):
def query(config_path:str, additional_context:str, question:str, tag:str, verbose:bool, json:bool, quiet:bool):
"""
Submit a gptty query
"""

asyncio.run(query_async_wrapper(config_path, question, tag, verbose, json, quiet))
asyncio.run(query_async_wrapper(config_path, question, tag, additional_context, verbose, json, quiet))


async def query_async_wrapper(config_path:str, question:str, tag:str, verbose:bool, json:bool, quiet:bool):
async def query_async_wrapper(config_path:str, question:str, tag:str, additional_context:str, verbose:bool, json:bool, quiet:bool):
# load the app configs
configs = get_config_data(config_file=config_path)

Expand All @@ -161,7 +162,7 @@ async def query_async_wrapper(config_path:str, question:str, tag:str, verbose:bo
# create the output file if it doesn't exist
with open (configs['output_file'], 'a'): pass

await run_query(questions=question, tag=tag, configs=configs, config_path=config_path, verbose=verbose, return_json=json, quiet=quiet)
await run_query(questions=question, tag=tag, configs=configs, additional_context=additional_context, config_path=config_path, verbose=verbose, return_json=json, quiet=quiet)


@click.command()
Expand Down
6 changes: 4 additions & 2 deletions gptty/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def get_context(tag: str,
question: str = None,
debug: bool = False):

# additional_context = additional_context.replace("\n",'')

if len(tag) < 1:
if model_type == 'v1/chat/completions':

Expand All @@ -99,7 +101,7 @@ def get_context(tag: str,

remaining_tokens = max_context_length - (len(question.split()))
if remaining_tokens > 0:
question = ' '.join(additional_context.split()[:remaining_tokens]) + question
question = ' '.join(additional_context.split()[:remaining_tokens]) + " " + question


if debug:
Expand Down Expand Up @@ -174,7 +176,7 @@ def get_context(tag: str,
# inexplicable responses.
remaining_tokens = max_context_length - (len(context.split()) + len(question.split()))
if remaining_tokens > 0:
context = ' '.join(additional_context.split()[:remaining_tokens]) + context
context = ' '.join(additional_context.split()[:remaining_tokens]) + " " + context


context = context.strip() + ' ' + question
Expand Down
4 changes: 2 additions & 2 deletions gptty/gptty.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def create_chat_room(configs=get_config_data(), log_responses:bool=True, c


# this is used when we run the `query` command
async def run_query(questions:list, tag:str, configs=get_config_data(), log_responses:bool=True, config_path=None, verbose:bool=False, return_json:bool=False, quiet:bool=False):
async def run_query(questions:list, tag:str, configs=get_config_data(), additional_context:str="", log_responses:bool=True, config_path=None, verbose:bool=False, return_json:bool=False, quiet:bool=False):

if not os.path.exists(config_path):
click.echo(f"{RED}FAILED to access app config file at {config_path}. Are you sure this is a valid config file? Run `gptty chat --help` for more information.")
Expand Down Expand Up @@ -284,7 +284,7 @@ async def run_query(questions:list, tag:str, configs=get_config_data(), log_resp
# we create the callable wait_graphic task
wait_task = asyncio.create_task(wait_graphic())

fully_contextualized_question = get_context(tag, configs['max_context_length'], configs['output_file'], model_engine, context_keywords_only=configs['context_keywords_only'], model_type=model_type, question=question, debug=verbose)
fully_contextualized_question = get_context(tag, configs['max_context_length'], configs['output_file'], model_engine, additional_context=additional_context, context_keywords_only=configs['context_keywords_only'], model_type=model_type, question=question, debug=verbose)

response_task = asyncio.create_task(fetch_response(fully_contextualized_question, model_engine, max_tokens, temperature, model_type))

Expand Down

0 comments on commit ec4c746

Please sign in to comment.