chat.py

  1. """
  2. References: https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/
  3. """
  4. import argparse
  5. import gradio as gr
  6. from llama_index.core import Settings
  7. from models.embedding import GLMEmbeddings
  8. from models.synthesizer import CodegeexSynthesizer
  9. from utils.vector import load_vectors


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--vector_path', type=str, help="path where the vectors are stored", default='vectors')
    parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
    parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
    parser.add_argument('--temperature', type=float, help="model's sampling temperature", default=0.2)
    return parser.parse_args()


def chat(query, history):
    """Stream the retrieved source documents first, then the model's answer."""
    resp = query_engine.query(query)
    ans = "Related Documents".center(150, '-') + '\n'
    yield ans
    for i, node in enumerate(resp.source_nodes):
        file_name = node.metadata['filename']
        ext = node.metadata['extension']
        text = node.text
        # Show each retrieved chunk as a fenced code block in the chat window.
        ans += f"File{i + 1}: {file_name}\n```{ext}\n{text}\n```\n"
        yield ans
    ans += "Model Response".center(150, '-') + '\n'
    ans += resp.response
    yield ans


if __name__ == '__main__':
    args = parse_arguments()
    Settings.embed_model = GLMEmbeddings()
    try:
        query_engine = load_vectors(args.vector_path).as_query_engine(
            response_synthesizer=CodegeexSynthesizer(args)
        )
    except Exception as e:
        print(f"Failed to load vectors: {e}")
        sys.exit(1)
    demo = gr.ChatInterface(chat).queue()
    demo.launch(server_name="127.0.0.1", server_port=8080)
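
For reference, load_vectors lives in utils/vector.py, which is not shown on this page. Below is a minimal sketch of what it could look like, assuming the index was persisted with llama_index's default on-disk storage; only the function name and signature come from the import above, and the body is an assumption, not the project's actual code.

from llama_index.core import StorageContext, load_index_from_storage

def load_vectors(vector_path):
    # Assumption: the index was previously written with
    # index.storage_context.persist(persist_dir=vector_path).
    storage_context = StorageContext.from_defaults(persist_dir=vector_path)
    return load_index_from_storage(storage_context)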