- """
- References: https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/
- """
- import argparse
- import gradio as gr
- from llama_index.core import Settings
- from models.embedding import GLMEmbeddings
- from models.synthesizer import CodegeexSynthesizer
- from utils.vector import load_vectors


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--vector_path', type=str, help="path to store the vectors", default='vectors')
    parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
    parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
    parser.add_argument('--temperature', type=float, help="model's temperature", default=0.2)
    return parser.parse_args()
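

# Example invocation using the flags above (the file name main.py is an
# assumption; --device cuda requires a GPU, otherwise keep the cpu default):
#   python main.py --vector_path vectors --device cuda --temperature 0.2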


def chat(query, history):
    """Stream the retrieved source documents first, then the synthesized answer."""
    resp = query_engine.query(query)
    ans = "相关文档".center(150, '-') + '\n'  # banner: "Related documents"
    yield ans
    for i, node in enumerate(resp.source_nodes):
        file_name = node.metadata['filename']
        ext = node.metadata['extension']
        text = node.text
        ans += f"File{i + 1}: {file_name}\n```{ext}\n{text}\n```\n"
        yield ans
    ans += "模型回复".center(150, '-') + '\n'  # banner: "Model response"
    ans += resp.response
    yield ans


if __name__ == '__main__':
    args = parse_arguments()
    Settings.embed_model = GLMEmbeddings()
    try:
        query_engine = load_vectors(args.vector_path).as_query_engine(
            response_synthesizer=CodegeexSynthesizer(args)
        )
    except Exception as e:
        print(f"Failed to load vectors, caused by {e}")
        exit()
    demo = gr.ChatInterface(chat).queue()
    demo.launch(server_name="127.0.0.1", server_port=8080)
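
# Once running, open http://127.0.0.1:8080 in a browser to use the chat UI.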