obs-av1.c 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. // SPDX-FileCopyrightText: 2023 David Rosca <[email protected]>
  2. //
  3. // SPDX-License-Identifier: GPL-2.0-or-later
  4. #include "obs-av1.h"
  5. #include "obs.h"
  6. static inline uint64_t leb128(const uint8_t *buf, size_t size, size_t *len)
  7. {
  8. uint64_t value = 0;
  9. uint8_t leb128_byte;
  10. *len = 0;
  11. for (int i = 0; i < 8; i++) {
  12. if (size-- < 1)
  13. break;
  14. (*len)++;
  15. leb128_byte = buf[i];
  16. value |= (leb128_byte & 0x7f) << (i * 7);
  17. if (!(leb128_byte & 0x80))
  18. break;
  19. }
  20. return value;
  21. }
  22. static inline unsigned int get_bits(uint8_t val, unsigned int n,
  23. unsigned int count)
  24. {
  25. return (val >> (8 - n - count)) & ((1 << (count - 1)) * 2 - 1);
  26. }
  27. static void parse_obu_header(const uint8_t *buf, size_t size, size_t *obu_start,
  28. size_t *obu_size, int *obu_type)
  29. {
  30. int extension_flag, has_size_field;
  31. size_t size_len = 0;
  32. *obu_start = 0;
  33. *obu_size = 0;
  34. *obu_type = 0;
  35. if (size < 1)
  36. return;
  37. *obu_type = get_bits(*buf, 1, 4);
  38. extension_flag = get_bits(*buf, 5, 1);
  39. has_size_field = get_bits(*buf, 6, 1);
  40. if (extension_flag)
  41. (*obu_start)++;
  42. (*obu_start)++;
  43. if (has_size_field)
  44. *obu_size = (size_t)leb128(buf + *obu_start, size - *obu_start,
  45. &size_len);
  46. else
  47. *obu_size = size - 1;
  48. *obu_start += size_len;
  49. }
  50. bool obs_av1_keyframe(const uint8_t *data, size_t size)
  51. {
  52. const uint8_t *start = data, *end = data + size;
  53. while (start < end) {
  54. size_t obu_start, obu_size;
  55. int obu_type;
  56. parse_obu_header(start, end - start, &obu_start, &obu_size,
  57. &obu_type);
  58. if (obu_size) {
  59. if (obu_type == OBS_OBU_FRAME ||
  60. obu_type == OBS_OBU_FRAME_HEADER) {
  61. uint8_t val = *(start + obu_start);
  62. if (!get_bits(val, 0, 1)) // show_existing_frame
  63. return get_bits(val, 1, 2) ==
  64. 0; // frame_type
  65. return false;
  66. }
  67. }
  68. start += obu_start + obu_size;
  69. }
  70. return false;
  71. }
  72. void obs_extract_av1_headers(const uint8_t *packet, size_t size,
  73. uint8_t **new_packet_data, size_t *new_packet_size,
  74. uint8_t **header_data, size_t *header_size)
  75. {
  76. DARRAY(uint8_t) new_packet;
  77. DARRAY(uint8_t) header;
  78. const uint8_t *start = packet, *end = packet + size;
  79. da_init(new_packet);
  80. da_init(header);
  81. while (start < end) {
  82. size_t obu_start, obu_size;
  83. int obu_type;
  84. parse_obu_header(start, end - start, &obu_start, &obu_size,
  85. &obu_type);
  86. if (obu_type == OBS_OBU_METADATA ||
  87. obu_type == OBS_OBU_SEQUENCE_HEADER) {
  88. da_push_back_array(header, start, obu_start + obu_size);
  89. }
  90. da_push_back_array(new_packet, start, obu_start + obu_size);
  91. start += obu_start + obu_size;
  92. }
  93. *new_packet_data = new_packet.array;
  94. *new_packet_size = new_packet.num;
  95. *header_data = header.array;
  96. *header_size = header.num;
  97. }