[Bug]: _collection attribute of ChromaVectorStore not copied during a deepcopy (when using DSPy) #14570

Open
theta-lin opened this issue Jul 4, 2024 · 2 comments
Labels
bug Something isn't working triage Issue needs to be triaged/prioritized

Comments

@theta-lin
Contributor

Bug Description

When a VectorIndexRetriever created from a ChromaVectorStore is used inside a DSPy module, compiling that module fails. The root cause appears to be that DSPy deepcopies a module before running it, but the _collection attribute of ChromaVectorStore is not copied during the deepcopy.

Here is a monkey patch that fixes this issue:

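from llama_index.vector_stores.chroma import ChromaVectorStore


def mydeepcopy(self, memo):
    # Return the same instance so the Chroma collection is shared, not copied.
    return self


# Import the class explicitly; a bare `import llama_index` may not load the chroma submodule.
ChromaVectorStore.__deepcopy__ = mydeepcopy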

Version

llama-index==0.10.51; llama-index-vector-stores-chroma==0.1.10

Steps to Reproduce

Minimal reproducible example:

import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from copy import deepcopy

db = chromadb.PersistentClient(path="./chroma_db")
chroma_collection = db.get_or_create_collection("my_collection")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
print(vector_store._collection)

c = deepcopy(vector_store)
print(c._collection)

Example for integration with DSPy:

import chromadb
import dspy
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore


class Rag(dspy.Module):
    def __init__(self):
        super().__init__()

        db = chromadb.PersistentClient(path="./chroma_db")
        chroma_collection = db.get_or_create_collection("my_collection")
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        index = VectorStoreIndex.from_vector_store(vector_store)
        self.vector_retriever = index.as_retriever(similarity_top_k=5)

    def forward(self, question):
        # Fails after DSPy deepcopies the module: the copied store lacks _collection.
        nodes = self.vector_retriever.retrieve(question)
        return dspy.Prediction(answer=str(nodes))

Relevant Logs/Tracebacks

Traceback (most recent call last):
  File "/code/./query.py", line 262, in <module>
    main()
  File "/code/./query.py", line 228, in main
    rag = teleprompter.compile(rag_with_assertions, trainset=trainset)
  File "/code/.venv/lib/python3.10/site-packages/dspy/teleprompt/bootstrap.py", line 84, in compile
    self._bootstrap()
  File "/code/.venv/lib/python3.10/site-packages/dspy/teleprompt/bootstrap.py", line 154, in _bootstrap
    success = self._bootstrap_one_example(example, round_idx)
  File "/code/.venv/lib/python3.10/site-packages/dspy/teleprompt/bootstrap.py", line 210, in _bootstrap_one_example
    raise e
  File "/code/.venv/lib/python3.10/site-packages/dspy/teleprompt/bootstrap.py", line 190, in _bootstrap_one_example
    prediction = teacher(**example.inputs())
  File "/code/.venv/lib/python3.10/site-packages/dspy/primitives/program.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "/code/.venv/lib/python3.10/site-packages/dspy/primitives/assertions.py", line 294, in forward
    return wrapped_forward(*args, **kwargs)
  File "/code/.venv/lib/python3.10/site-packages/dspy/primitives/assertions.py", line 215, in wrapper
    result = bypass_suggest_handler(func)(*args, **kwargs) if bypass_suggest else None
  File "/code/.venv/lib/python3.10/site-packages/dspy/primitives/assertions.py", line 148, in wrapper
    return func(*args, **kwargs)
  File "/code/./query.py", line 161, in forward
    nodes = retriever.retrieve(question)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/core/base/base_retriever.py", line 243, in retrieve
    nodes = self._retrieve(query_bundle)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/core/indices/vector_store/retrievers/retriever.py", line 101, in _retrieve
    return self._get_nodes_with_embeddings(query_bundle)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/core/indices/vector_store/retrievers/retriever.py", line 177, in _get_nodes_with_embeddings
    query_result = self._vector_store.query(query, **self._kwargs)
  File "/code/.venv/lib/python3.10/site-packages/llama_index/vector_stores/chroma/base.py", line 371, in query
    return self._query(
  File "/code/.venv/lib/python3.10/site-packages/llama_index/vector_stores/chroma/base.py", line 381, in _query
    results = self._collection.query(
AttributeError: 'ChromaVectorStore' object has no attribute '_collection'. Did you mean: 'from_collection'?
@theta-lin theta-lin added bug Something isn't working triage Issue needs to be triaged/prioritized labels Jul 4, 2024
@logan-markewich
Collaborator

@theta-lin would love a suggestion for a fix besides monkeypatching

@theta-lin
Contributor Author

@logan-markewich Indeed, monkeypatching is just my temporary workaround.

After further investigation, I think the root cause is that an sqlite3.Connection object is not picklable, which makes sense, as a database connection should be shared rather than copied. The issue can be illustrated with the following code:

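from copy import deepcopy

import chromadb
from pydantic import BaseModel, PrivateAttr


class C(BaseModel):
    _collection = PrivateAttr()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        db = chromadb.PersistentClient(path="./chroma_db")
        self._collection = db.get_or_create_collection("my_collection")


c = C()
print(c._collection)

# The collection holds a reference to an sqlite3.Connection, so this raises.
c_copy = deepcopy(c)
print(c_copy._collection)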

It gives the error message TypeError: cannot pickle 'sqlite3.Connection' object.
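The same failure can probably be reproduced with the standard library alone, without Chroma or Pydantic. The snippet below is only a minimal sketch: it assumes an in-memory SQLite database is a fair stand-in for the connection that PersistentClient holds internally, which I have not verified against Chroma's source.

```python
import sqlite3
from copy import deepcopy

# Assumption: an in-memory connection behaves like the one Chroma keeps internally.
conn = sqlite3.connect(":memory:")

try:
    deepcopy(conn)
    error = None
except TypeError as exc:
    # deepcopy falls back to the pickle protocol, which sqlite3.Connection rejects.
    error = exc

print(type(error).__name__, error)
conn.close()
```

If this raises as expected, any object graph that reaches such a connection will fail to deepcopy unless something along the way overrides the copying behavior.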

While I understand that the default behavior of copy.deepcopy() can be changed by providing a custom __deepcopy__() method, I am not sure whether overriding it directly is a good course of action here, as ChromaVectorStore is a Pydantic object. Additionally, the error message mentioned above is not shown when the actual ChromaVectorStore object is pickled, so maybe some of the default copying behavior has already been modified? What do you think is the best way to fix it in this case?
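For reference, one possible shape for a custom __deepcopy__() that shares the connection while still copying everything else is sketched below. This uses a plain Python class, not a Pydantic model: Store, _conn, and settings are hypothetical names for illustration, and whether the same approach carries over cleanly to ChromaVectorStore is an untested assumption.

```python
import sqlite3
from copy import deepcopy


class Store:
    """Hypothetical stand-in for a store that owns an uncopyable handle."""

    def __init__(self, conn, settings):
        self._conn = conn         # shared resource, should not be copied
        self.settings = settings  # ordinary state, safe to deep-copy

    def __deepcopy__(self, memo):
        cls = self.__class__
        clone = cls.__new__(cls)
        # Register the clone first so reference cycles resolve to it.
        memo[id(self)] = clone
        for name, value in self.__dict__.items():
            if name == "_conn":
                # Share the handle instead of trying to copy it.
                clone.__dict__[name] = value
            else:
                clone.__dict__[name] = deepcopy(value, memo)
        return clone


store = Store(sqlite3.connect(":memory:"), {"top_k": 5})
clone = deepcopy(store)
print(clone is store, clone._conn is store._conn, clone.settings is store.settings)
```

Compared with the monkey patch that returns self, this keeps the copy independent for regular attributes, so mutating the copy's settings does not leak back into the original.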
