main.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import json
  2. import re
  3. from json import JSONDecodeError
  4. import torch
  5. from transformers import AutoModelForCausalLM, AutoTokenizer
  6. def main():
  7. device = "cuda" if torch.cuda.is_available() else "cpu"
  8. model_name_or_path = "THUDM/codegeex4-all-9b"
  9. tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
  10. model = AutoModelForCausalLM.from_pretrained(
  11. model_name_or_path,
  12. torch_dtype=torch.bfloat16,
  13. trust_remote_code=True
  14. ).to(device).eval()
  15. tool_content = {
  16. "function": [
  17. {
  18. "name": "weather",
  19. "description": "Use for searching weather at a specific location",
  20. "parameters": {
  21. "type": "object",
  22. "properties": {
  23. "location": {
  24. "description": "the location need to check the weather",
  25. "type": "str",
  26. }
  27. },
  28. "required": [
  29. "location"
  30. ]
  31. }
  32. }
  33. ]
  34. }
  35. response, _ = model.chat(
  36. tokenizer,
  37. query="Tell me about the weather in Beijing",
  38. history=[{"role": "tool", "content": tool_content}],
  39. max_new_tokens=1024,
  40. temperature=0.1
  41. )
  42. # support parallel calls, thus the result is a list
  43. functions = post_process(response)
  44. try:
  45. return [json.loads(func) for func in functions if func]
  46. # get rid of some possible invalid formats
  47. except JSONDecodeError:
  48. try:
  49. return [json.loads(func.replace('(', '[').replace(')', ']')) for func in functions if func]
  50. except JSONDecodeError:
  51. try:
  52. return [json.loads(func.replace("'", '"')) for func in functions if func]
  53. except JSONDecodeError as e:
  54. return [{"answer": response, "errors": e}]
  55. def post_process(text: str) -> list[str]:
  56. """
  57. Process model's response.
  58. In case there are parallel calls, each call is warpped with ```json```.
  59. """
  60. pattern = r'```json(.*?)```'
  61. matches = re.findall(pattern, text, re.DOTALL)
  62. return matches
  63. if __name__ == '__main__':
  64. output = main()
  65. print(output) # [{"name": "weather", "arguments": {"location": "Beijing"}}]