-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
在使用rebel提取实体的时候发生了下面的警告,且好像rebel并没有起到作用,在进行检索和问答时,我的图数据库中没有任何实体和关系,所以没有结果返回。 #14621
Comments
要解决弃用警告并确保正确提取和存储实体和关系,请移除 from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex, StorageContext, Settings
import nest_asyncio
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import PromptTemplate
from llama_index.graph_stores.neo4j import Neo4jGraphStore
import torch
from transformers import pipeline
nest_asyncio.apply()
def extract_triplets(text):
triplets = []
relation, subject, relation, object_ = '', '', '', ''
text = text.strip()
current = 'x'
for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
if token == "<triplet>":
current = 't'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
relation = ''
subject = ''
elif token == "<subj>":
current = 's'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
object_ = ''
elif token == "<obj>":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
return triplets
# Load documents
documents = SimpleDirectoryReader('/home/cdhd/PythonCode/pythonProject/TestFile/公司主要人员架构类').load_data()
print("documents", documents[0].text)
SYSTEM_PROMPT = """You are a helpful AI assistant.
"""
query_wrapper_prompt = PromptTemplate(
"[INST]<<SYS>>\n" + SYSTEM_PROMPT + "<</SYS>>\n\n{query_str}[/INST] "
)
llm = HuggingFaceLLM(
context_window=4096,
max_new_tokens=1024,
generate_kwargs={"temperature": 0.2, "do_sample": True},
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
model_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
device_map="auto",
model_kwargs={"torch_dtype": torch.float16},
)
Settings.llm = llm
triplet_extractor = pipeline('text2text-generation', model='/home/cdhd/PythonCode/pythonProject/model/rebel', tokenizer='/home/cdhd/PythonCode/pythonProject/model/rebel')
extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(documents[0].text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])
embed_model = HuggingFaceEmbedding(
model_name='/home/cdhd/PythonCode/pythonProject/model/bge-large-zh-v1.5/models--BAAI--bge-large-zh-v1.5/snapshots/79e7739b6ab944e86d6171e44d24c997fc1e0116',
)
Settings.embed_model = embed_model
graph_store = Neo4jGraphStore(
username="neo4j",
password="12345678",
url="bolt://localhost:7687",
database="neo4j"
)
# 清理数据集
graph_store.query(
"""
MATCH (n) DETACH DELETE n
"""
)
# 创建存储上下文
storage_context = StorageContext.from_defaults(graph_store=graph_store)
extracted_triplets = extract_triplets(extracted_text[0])
# 构建包含嵌入的索引
index = KnowledgeGraphIndex.from_documents(
documents,
kg_triplet_extract_fn=extract_triplets,
storage_context=storage_context,
max_triplets_per_chunk=2,
include_embeddings=True,
embed_model=embed_model,
llm=llm,
)
retriever = index.as_retriever()
ans = retriever.retrieve("公司的董事长是谁")
print(ans)
for node in ans:
print(node.text)
query_engine = index.as_query_engine(
include_text=True,
response_mode="tree_summarize",
llm=llm
)
response = query_engine.query(
"公司的董事长是谁?"
)
print("response:", response) 此更新移除了 |
我这个程序执行,rebel依然没有成功抽取出实体,neo4j图数据中依然没有任何实体和关系,这是怎么回事儿,有其他的解决方法吗? |
可以尝试以下解决方法:
以下是 def extract_triplets(input_text):
text = triplet_extractor.tokenizer.batch_decode(
[
triplet_extractor(
input_text, return_tensors=True, return_text=False
)[0]["generated_token_ids"]
]
)[0]
triplets = []
relation, subject, relation, object_ = "", "", "", ""
text = text.strip()
current = "x"
for token in (
text.replace("<s>", "")
.replace("<pad>", "")
.replace("</s>", "")
.split()
):
if token == "<triplet>":
current = "t"
if relation != "":
triplets.append(
(subject.strip(), relation.strip(), object_.strip())
)
relation = ""
subject = ""
elif token == "<subj>":
current = "s"
if relation != "":
triplets.append(
(subject.strip(), relation.strip(), object_.strip())
)
object_ = ""
elif token == "<obj>":
current = "o"
relation = ""
else:
if current == "t":
subject += " " + token
elif current == "s":
object_ += " " + token
elif current == "o":
relation += " " + token
if subject != "" and relation != "" and object_ != "":
triplets.append((subject.strip(), relation.strip(), object_.strip()))
return triplets 如果使用了 def extract_triplets_wiki(text):
relations = extract_triplets(text)
filtered_relations = []
for relation in relations:
(subj, rel, obj) = relation
filtered_subj = wiki_filter.filter(subj)
filtered_obj = wiki_filter.filter(obj)
# skip if at least one entity not linked to wiki
if filtered_subj is None and filtered_obj is None:
continue
filtered_relations.append(
(
filtered_subj or subj,
rel,
filtered_obj or obj,
)
)
return filtered_relations 确保 class WikiFilter:
def __init__(self):
self.cache = {}
def filter(self, candidate_entity):
# check the cache to avoid network calls
if candidate_entity in self.cache:
return self.cache[candidate_entity]["title"]
# pull the page from wikipedia -- if it exists
try:
page = wikipedia.page(candidate_entity, auto_suggest=False)
entity_data = {
"title": page.title,
"url": page.url,
"summary": page.summary,
}
# cache the page title and original entity
self.cache[candidate_entity] = entity_data
self.cache[page.title] = entity_data
return entity_data["title"]
except:
return None 如果以上检查都没有问题,建议检查从 ArxivReader 加载的数据是否包含有效的文本[1]。 |
Question Validation
Question
警告:DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use
llama_index.settings.Settings
instead.) -- Deprecated since version 0.10.0.service_context = ServiceContext.from_defaults(llm=llm,embed_model=embed_model,chunk_size=256)
代码如下:
from llama_index.core import SimpleDirectoryReader,KnowledgeGraphIndex,StorageContext,Settings,ServiceContext
import nest_asyncio
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import PromptTemplate
from llama_index.graph_stores.neo4j import Neo4jGraphStore
import torch
from transformers import pipeline
import yapf
nest_asyncio.apply()
def extract_triplets(text):
triplets = []
relation, subject, relation, object_ = '', '', '', ''
text = text.strip()
current = 'x'
for token in text.replace("
", "").replace("", "").replace("", "").split():if token == "":
current = 't'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
relation = ''
subject = ''
elif token == "":
current = 's'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
object_ = ''
elif token == "":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
return triplets
Load documents
documents = SimpleDirectoryReader('/home/cdhd/PythonCode/pythonProject/TestFile/公司主要人员架构类').load_data()
print("documents",documents[0].text)
model_config = {"protected_namespaces": ()}
SYSTEM_PROMPT = """You are a helpful AI assistant.
"""
query_wrapper_prompt = PromptTemplate(
"[INST]<>\n" + SYSTEM_PROMPT + "<>\n\n{query_str}[/INST] "
)
llm = HuggingFaceLLM(
context_window=4096,
max_new_tokens=1024,
generate_kwargs={"temperature": 0.2, "do_sample": True},
query_wrapper_prompt=query_wrapper_prompt,
# tokenizer_name='/home/cdhd/PythonCode/pythonProject/model/gpt2',
# model_name='/home/cdhd/PythonCode/pythonProject/model/gpt2',
tokenizer_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
model_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
device_map="auto",
# change these settings below depending on your GPU
model_kwargs={"torch_dtype": torch.float16},
)
Settings.llm = llm
triplet_extractor = pipeline('text2text-generation', model='/home/cdhd/PythonCode/pythonProject/model/rebel', tokenizer='/home/cdhd/PythonCode/pythonProject/model/rebel')
llm = AutoModel.from_pretrained('/home/cdhd/Desktop/glm4/glm-4-9b-chat', trust_remote_code=True)
We need to use the tokenizer manually since we need special tokens.
extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(documents[0].text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])
embed_model = HuggingFaceEmbedding(
model_name='/home/cdhd/PythonCode/pythonProject/model/bge-large-zh-v1.5/models--BAAI--bge-large-zh-v1.5/snapshots/79e7739b6ab944e86d6171e44d24c997fc1e0116',
# model_name = r"E:\Huggingface_model\bge-large-zh-v1.5",
# cache_folder=r"E:\Huggingface_model\bge-large-zh-v1.5\models--BAAI--bge-large-zh-v1.5\snapshots\79e7739b6ab944e86d6171e44d24c997fc1e0116",
)
Settings.embed_model = embed_model
graph_store = Neo4jGraphStore(
username="neo4j",
password="12345678",
url="bolt://localhost:7687",
database="neo4j"
)
service_context = ServiceContext.from_defaults(llm=llm,embed_model=embed_model,chunk_size=256)
清理数据集
graph_store.query(
"""
MATCH (n) DETACH DELETE n
"""
)
创建存储上下文
storage_context = StorageContext.from_defaults(graph_store=graph_store)
extracted_triplets = extract_triplets(extracted_text[0])
构建包含嵌入的索引
index = KnowledgeGraphIndex.from_documents(
documents,
kg_triplet_extract_fn=extract_triplets,
service_context=service_context,
storage_context=storage_context,
max_triplets_per_chunk=2,
include_embeddings=True,
embed_model=embed_model,
llm=llm,
)
retriever = index.as_retriever()
ans = retriever.retrieve("公司的董事长是谁")
print(ans)
for node in ans:
print(node.text)
query_engine = index.as_query_engine(
include_text = True,
response_mode="tree_summarize",
llm=llm
)
response = query_engine.query(
"公司的董事长是谁?"
)
print("response:",response)
The text was updated successfully, but these errors were encountered: