Skip to content

Commit

Permalink
Added: docstrings for most functions in gptty/ (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
signebedi committed Apr 5, 2023
1 parent 7455d24 commit da02922
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Tests
name: tests

on:
push:
Expand Down
3 changes: 1 addition & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
include requirements.txt
include assets/context_example.png
include assets/question_chain_example.png
include assets/*.png
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# gptty

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/signebedi/gptty/blob/master/LICENSE)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/signebedi/gptty/blob/master/LICENSE)
[![PyPI version](https://badge.fury.io/py/gptty.svg)](https://pypi.python.org/pypi/gptty)
[![Downloads](https://static.pepy.tech/badge/gptty)](https://pepy.tech/project/gptty)
[![gptty tests](https://github.com/signebedi/gptty/workflows/Tests/badge.svg)](https://github.com/signebedi/gptty/actions)
[![gptty tests](https://github.com/signebedi/gptty/workflows/tests/badge.svg)](https://github.com/signebedi/gptty/actions)

ChatGPT wrapper in your TTY

Expand Down
24 changes: 24 additions & 0 deletions gptty/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,30 @@

# parse config data
def get_config_data(config_file='gptty.ini'):

"""
The get_config_data() function reads a configuration file and returns a dictionary containing the parsed data.
If the configuration file does not exist or does not contain a particular key, a default value is used.
The default configuration values are defined in the function itself.
By default, the function reads the gptty.ini configuration file, but you can specify a different file name by passing it as an argument.
The returned dictionary has the following keys:
api_key: An API key used to authenticate with the OpenAI API.
your_name: The name of the user (you) who is running the program.
gpt_name: The name of the GPT model to use.
output_file: The name of the output file to write the generated text to.
model: The ID of the GPT model to use.
temperature: The temperature value to use when generating text.
max_tokens: The maximum number of tokens to generate in the generated text.
max_context_length: The maximum number of tokens to use as context when generating text.
context_keywords_only: A boolean value indicating whether to use only the keywords in the context when generating text.
preserve_new_lines: A boolean value indicating whether to preserve new lines in the generated text.
verify_internet_endpoint: The internet endpoint to use when verifying the internet connection.
Note: This function uses the configparser module to parse configuration files.
"""

# create a configuration object
config = configparser.ConfigParser()

Expand Down
65 changes: 58 additions & 7 deletions gptty/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,38 @@
YELLOW = "\033[1;33m"
RESET = "\033[0m"

def verify_added_phrase(phrase:str,context:str, max_len:int) -> bool:
def get_token_count(s, model_name):

if len(context) + len(phrase) <= max_len:
return True
"""
Returns the number of tokens in a text string encoded using a specified model.
return False
Args:
s (str): The input text string.
model_name (str): The name of the model used for encoding.
Returns:
num_tokens (int): The number of tokens in the encoded text string.
"""

def get_token_count(s, model_name):
    """Return how many tokens *s* produces under *model_name*'s tokenizer.

    Args:
        s: The text to tokenize.
        model_name: OpenAI model whose encoding should be used.

    Returns:
        int: Token count of the encoded string.
    """
    return len(tiktoken.encoding_for_model(model_name).encode(s))

def return_most_common_phrases(text:str, weight_recent=True) -> list:

"""
Returns a list of the most common noun phrases in the input text, with an option to weight more recent phrases more heavily.
Args:
- text (str): The input text.
- weight_recent (bool): If True, more recent phrases are weighted more heavily.
Returns:
- list: A list of the most common noun phrases in the input text. Each item in the list is a string representing a noun phrase.
"""

# Extract noun phrases using TextBlob
blob = TextBlob(text)
noun_phrases = blob.noun_phrases
Expand Down Expand Up @@ -74,7 +91,41 @@ def get_context(tag: str,
question: str = None,
debug: bool = False):

# additional_context = additional_context.replace("\n",'')

"""
Returns a full query context for a given tag, question and additional context.
Parameters:
tag: str
Tag to identify a conversation with a specific topic
max_context_length: int
Maximum length of the context to return.
output_file: str
Path to the file to read the context from.
model_name: str
Name of the language model to use
context_keywords_only: bool, optional
If True, use only the most common phrases and words from the context and additional context.
Default is True.
additional_context: str, optional
Additional context to add to the context.
Default is an empty string.
model_type: str, optional
Type of the language model. If 'v1/chat/completions', return a list of dicts with 'role' and 'content' keys
If not, return a string.
Default is None.
question: str, optional
Question to add to the context. If None, return only the context.
Default is None.
debug: bool, optional
If True, print debug information.
Default is False.
Returns:
If `model_type` is 'v1/chat/completions', returns a list of dicts with 'role' and 'content' keys
If not, returns a string.
"""


if len(tag) < 1:
if model_type == 'v1/chat/completions':
Expand Down
100 changes: 100 additions & 0 deletions gptty/gptty.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,45 @@
## VALIDATE MODELS - these functions are use to validate the model passed by the user and raises an exception if
## the model does not exist.
def get_available_models():
    """Fetch the IDs of all models the OpenAI API currently exposes.

    Returns:
        list: Model ID strings reported by ``openai.Model.list()``.
    """
    listing = openai.Model.list()
    return [entry.id for entry in listing['data']]

def is_valid_model(model_name):
    """Check whether *model_name* is offered by the OpenAI platform.

    Args:
        model_name (str): Model ID to look up.

    Returns:
        bool: True if the ID appears in the platform's model list.
    """
    return model_name in get_available_models()

def validate_model_type(model_name):

"""
Validates whether a given model name is a supported model type for OpenAI API completion requests.
Parameters:
- model_name (str): The name of the model to validate.
Returns:
- str: The API endpoint to use for completion requests if the model name is valid and supported.
Raises:
- Exception: If the model name is not valid or not supported.
"""

if ('davinci' in model_name or 'curie' in model_name) and is_valid_model(model_name):
return 'v1/completions'
elif 'gpt' in model_name and is_valid_model(model_name):
Expand All @@ -66,6 +97,23 @@ def validate_model_type(model_name):
# here we define the async call to the openai API that is used when running queries
async def fetch_response(prompt, model_engine, max_tokens, temperature, model_type):

"""
Fetches a response from the OpenAI API based on the given prompt and model specifications.
Parameters:
- prompt (str): The prompt to use for the API request.
- model_engine (str): The engine ID to use for the API request.
- max_tokens (int): The maximum number of tokens to generate in the response.
- temperature (float): The temperature to use for the API request.
- model_type (str): The API endpoint to use for the API request.
Returns:
- OpenAICompletion: The completion response object from the OpenAI API.
Raises:
- Exception: If the model type is not recognized or supported.
"""

if model_type == 'v1/completions':

return await openai.Completion.acreate(
Expand Down Expand Up @@ -98,6 +146,14 @@ async def fetch_response(prompt, model_engine, max_tokens, temperature, model_ty

# here we design the wait graphic that is called while awaiting responses
async def wait_graphic():

"""
Displays an animated wait graphic on the console while a response is being awaited.
Returns:
- None
"""

while True:
# for i in range(1, 11):
# print("." * i + " " * (9 - i), end="", flush=True)
Expand All @@ -110,9 +166,27 @@ async def wait_graphic():
await asyncio.sleep(0.1)
print("\b" * 10, end="", flush=True)



# this is used when we run the `chat` command
async def create_chat_room(configs=get_config_data(), log_responses:bool=True, config_path=None, verbose:bool=False):

"""
This function creates a chat room using the OpenAI API to generate responses to user inputs.
The user input is prompted and the response is displayed on the console.
The chat session is continuously open until the user enters ':quit' or ':q' to terminate the session.
The session log is stored in a csv file.
Parameters:
- configs: A dictionary containing OpenAI API key, model name, temperature, max_tokens, max_context_length, context_keywords_only, preserve_new_lines, gpt_name and your_name.
- log_responses: A boolean indicating whether or not to log the responses in a csv file. Default is True.
- config_path: The path to the configuration file.
- verbose: A boolean indicating whether or not to print debugging information. Default is False.
Returns:
- None
"""

# Authenticate with OpenAI using your API key
# click.echo (configs['api_key'])
if configs['api_key'].rstrip('\n') == "":
Expand Down Expand Up @@ -230,6 +304,32 @@ async def create_chat_room(configs=get_config_data(), log_responses:bool=True, c
# this is used when we run the `query` command
async def run_query(questions:list, tag:str, configs=get_config_data(), additional_context:str="", log_responses:bool=True, config_path=None, verbose:bool=False, return_json:bool=False, quiet:bool=False):

"""
This function is used to run a query command using OpenAI.
It takes in a list of questions, a tag, additional context, and various configuration options.
It authenticates with OpenAI using the API key specified in the configuration file, and then continuously sends and receives messages until all the questions
have been answered. The responses are either printed to the console in color or returned in a JSON format, depending on the options specified. Additionally,
the function logs the questions and responses in a pandas dataframe if specified in the configuration file.
Parameters:
questions (list): a list of questions to ask the GPT-3 model
tag (str): a tag to associate with the questions and responses
configs (dict): a dictionary containing configuration options (default: get_config_data())
additional_context (str): additional context to provide to the GPT-3 model (default: "")
log_responses (bool): whether to log the questions and responses in a pandas dataframe (default: True)
config_path (str): the path to the configuration file (default: None)
verbose (bool): whether to enable debug mode (default: False)
return_json (bool): whether to return the responses in a JSON format (default: False)
quiet (bool): whether to suppress console output (default: False)
Returns:
None if the function fails to authenticate with OpenAI or if there are no questions to ask
if return_json is True and quiet is False, prints a JSON representation of the question/response data to the console and returns None
if return_json is True and quiet is True, returns a JSON representation of the question/response data
if return_json is False and quiet is False, prints the responses to the console in color and returns None
if return_json is False and quiet is True, returns None
"""

if not os.path.exists(config_path):
click.echo(f"{RED}FAILED to access app config file at {config_path}. Are you sure this is a valid config file? Run `gptty chat --help` for more information.")
return
Expand Down
Loading

0 comments on commit da02922

Please sign in to comment.