// fixer.go
  1. package sub_timeline_fixer
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/allanpk716/ChineseSubFinder/internal/common"
  6. "github.com/allanpk716/ChineseSubFinder/internal/logic/sub_parser/ass"
  7. "github.com/allanpk716/ChineseSubFinder/internal/logic/sub_parser/srt"
  8. "github.com/allanpk716/ChineseSubFinder/internal/pkg/sub_parser_hub"
  9. "github.com/go-echarts/go-echarts/v2/opts"
  10. "github.com/grd/stat"
  11. "github.com/james-bowman/nlp"
  12. "github.com/james-bowman/nlp/measures/pairwise"
  13. "github.com/mndrix/tukey"
  14. "gonum.org/v1/gonum/mat"
  15. "strings"
  16. "time"
  17. )
  18. // StopWordCounter 停止词统计
  19. func StopWordCounter(inString string, per int) []string {
  20. statisticTimes := make(map[string]int)
  21. wordsLength := strings.Fields(inString)
  22. for counts, word := range wordsLength {
  23. // 判断key是否存在,这个word是字符串,这个counts是统计的word的次数。
  24. word, ok := statisticTimes[word]
  25. if ok {
  26. word = word
  27. statisticTimes[wordsLength[counts]] = statisticTimes[wordsLength[counts]] + 1
  28. } else {
  29. statisticTimes[wordsLength[counts]] = 1
  30. }
  31. }
  32. stopWords := make([]string, 0)
  33. mapByValue := sortMapByValue(statisticTimes)
  34. breakIndex := len(mapByValue) * per / 100
  35. for index, wordInfo := range mapByValue {
  36. if index > breakIndex {
  37. break
  38. }
  39. stopWords = append(stopWords, wordInfo.Name)
  40. }
  41. return stopWords
  42. }
  43. // NewTFIDF 初始化 TF-IDF
  44. func NewTFIDF(testCorpus []string) (*nlp.Pipeline, mat.Matrix, error) {
  45. newCountVectoriser := nlp.NewCountVectoriser(StopWords...)
  46. transformer := nlp.NewTfidfTransformer()
  47. // set k (the number of dimensions following truncation) to 4
  48. reducer := nlp.NewTruncatedSVD(4)
  49. lsiPipeline := nlp.NewPipeline(newCountVectoriser, transformer, reducer)
  50. // Transform the corpus into an LSI fitting the model to the documents in the process
  51. lsi, err := lsiPipeline.FitTransform(testCorpus...)
  52. if err != nil {
  53. return nil, lsi, errors.New(fmt.Sprintf("Failed to process testCorpus documents because %v", err))
  54. }
  55. return lsiPipeline, lsi, nil
  56. }
  57. // GetOffsetTime 暂时只支持英文的基准字幕,源字幕必须是双语中英字幕
  58. func GetOffsetTime(baseEngSubFPath, srcSubFPath string) (time.Duration, error) {
  59. subParserHub := sub_parser_hub.NewSubParserHub(ass.NewParser(), srt.NewParser())
  60. bFind, infoBase, err := subParserHub.DetermineFileTypeFromFile(baseEngSubFPath)
  61. if err != nil {
  62. return 0, err
  63. }
  64. if bFind == false {
  65. return 0, nil
  66. }
  67. bFind, infoSrc, err := subParserHub.DetermineFileTypeFromFile(srcSubFPath)
  68. if err != nil {
  69. return 0, err
  70. }
  71. if bFind == false {
  72. return 0, nil
  73. }
  74. // 构建基准语料库,目前阶段只需要考虑是 En 的就行了
  75. var baseCorpus = make([]string, 0)
  76. for _, oneDialogueEx := range infoBase.DialoguesEx {
  77. baseCorpus = append(baseCorpus, oneDialogueEx.EnLine)
  78. }
  79. // 初始化
  80. pipLine, tfidf, err := NewTFIDF(baseCorpus)
  81. if err != nil {
  82. return 0, err
  83. }
  84. /*
  85. 确认两个字幕间的偏移,暂定的方案是两边都连续匹配上 5 个索引,再抽取一个对话的时间进行修正计算
  86. */
  87. maxCompareDialogue := 5
  88. // 基线的长度
  89. _, docsLength := tfidf.Dims()
  90. var matchIndexList = make([]MatchIndex, 0)
  91. sc := NewSubCompare(maxCompareDialogue)
  92. // 开始比较相似度,默认认为是 Ch_en 就行了
  93. for srcIndex, srcOneDialogueEx := range infoSrc.DialoguesEx {
  94. // 这里只考虑 英文 的语言
  95. if srcOneDialogueEx.EnLine == "" {
  96. continue
  97. }
  98. // run the query through the same pipeline that was fitted to the corpus and
  99. // to project it into the same dimensional space
  100. queryVector, err := pipLine.Transform(srcOneDialogueEx.EnLine)
  101. if err != nil {
  102. return 0, err
  103. }
  104. // iterate over document feature vectors (columns) in the LSI matrix and compare
  105. // with the query vector for similarity. Similarity is determined by the difference
  106. // between the angles of the vectors known as the cosine similarity
  107. highestSimilarity := -1.0
  108. // 匹配上的基准的索引
  109. var baseIndex int
  110. // 这里理论上需要把所有的基线遍历一次,但是,一般来说,两个字幕不可能差距在 50 行
  111. // 这样的好处是有助于提高搜索的性能
  112. // 那么就以当前的 src 的位置,向前、向后各 50 来遍历
  113. nowMaxScanLength := srcIndex + 50
  114. nowMinScanLength := srcIndex - 50
  115. if nowMinScanLength < 0 {
  116. nowMinScanLength = 0
  117. }
  118. if nowMaxScanLength > docsLength {
  119. nowMaxScanLength = docsLength
  120. }
  121. for i := nowMinScanLength; i < nowMaxScanLength; i++ {
  122. similarity := pairwise.CosineSimilarity(queryVector.(mat.ColViewer).ColView(0), tfidf.(mat.ColViewer).ColView(i))
  123. if similarity > highestSimilarity {
  124. baseIndex = i
  125. highestSimilarity = similarity
  126. }
  127. }
  128. if sc.Add(baseIndex, srcIndex) == false {
  129. sc.Clear()
  130. sc.Add(baseIndex, srcIndex)
  131. }
  132. if sc.Check() == false {
  133. continue
  134. }
  135. startBaseIndex, startSrcIndex := sc.GetStartIndex()
  136. matchIndexList = append(matchIndexList, MatchIndex{
  137. BaseNowIndex: startBaseIndex,
  138. SrcNowIndex: startSrcIndex,
  139. Similarity: highestSimilarity,
  140. })
  141. //println(fmt.Sprintf("Similarity: %f Base[%d] %s-%s '%s' <--> Src[%d] %s-%s '%s'",
  142. // highestSimilarity,
  143. // baseIndex, infoBase.DialoguesEx[baseIndex].StartTime, infoBase.DialoguesEx[baseIndex].EndTime, baseCorpus[baseIndex],
  144. // srcIndex, srcOneDialogueEx.StartTime, srcOneDialogueEx.EndTime, srcOneDialogueEx.EnLine))
  145. }
  146. timeFormat := ""
  147. if infoBase.Ext == common.SubExtASS || infoBase.Ext == common.SubExtSSA {
  148. timeFormat = timeFormatAss
  149. } else {
  150. timeFormat = timeFormatSrt
  151. }
  152. var startDiffTimeLineData = make([]opts.LineData, 0)
  153. var endDiffTimeLineData = make([]opts.LineData, 0)
  154. var tmpStartDiffTime = make([]float64, 0)
  155. var tmpEndDiffTime = make([]float64, 0)
  156. var startDiffTimeList = make(stat.Float64Slice, 0)
  157. var endDiffTimeList = make(stat.Float64Slice, 0)
  158. var xAxis = make([]string, 0)
  159. // 上面找出了连续匹配 maxCompareDialogue:N 次的字幕语句块
  160. // 求出平均时间偏移
  161. for mIndex, matchIndexItem := range matchIndexList {
  162. for i := 0; i < maxCompareDialogue; i++ {
  163. // 这里会统计连续的这 5 句话的时间差
  164. tmpBaseIndex := matchIndexItem.BaseNowIndex + i
  165. tmpSrcIndex := matchIndexItem.SrcNowIndex + i
  166. baseTimeStart, err := time.Parse(timeFormat, infoBase.DialoguesEx[tmpBaseIndex].StartTime)
  167. if err != nil {
  168. println("baseTimeStart", err)
  169. continue
  170. }
  171. baseTimeEnd, err := time.Parse(timeFormat, infoBase.DialoguesEx[tmpBaseIndex].EndTime)
  172. if err != nil {
  173. println("baseTimeEnd", err)
  174. continue
  175. }
  176. srtTimeStart, err := time.Parse(timeFormat, infoSrc.DialoguesEx[tmpSrcIndex].StartTime)
  177. if err != nil {
  178. println("srtTimeStart", err)
  179. continue
  180. }
  181. srtTimeEnd, err := time.Parse(timeFormat, infoSrc.DialoguesEx[tmpSrcIndex].EndTime)
  182. if err != nil {
  183. println("srtTimeEnd", err)
  184. continue
  185. }
  186. TimeDiffStart := baseTimeStart.Sub(srtTimeStart)
  187. TimeDiffEnd := baseTimeEnd.Sub(srtTimeEnd)
  188. startDiffTimeLineData = append(startDiffTimeLineData, opts.LineData{Value: TimeDiffStart.Seconds()})
  189. endDiffTimeLineData = append(endDiffTimeLineData, opts.LineData{Value: TimeDiffEnd.Seconds()})
  190. tmpStartDiffTime = append(tmpStartDiffTime, TimeDiffStart.Seconds())
  191. tmpEndDiffTime = append(tmpEndDiffTime, TimeDiffEnd.Seconds())
  192. startDiffTimeList = append(startDiffTimeList, TimeDiffStart.Seconds())
  193. endDiffTimeList = append(endDiffTimeList, TimeDiffEnd.Seconds())
  194. xAxis = append(xAxis, fmt.Sprintf("%d_%d", mIndex, i))
  195. //println(fmt.Sprintf("Diff Start-End: %s - %s Base[%d] %s-%s '%s' <--> Src[%d] %s-%s '%s'",
  196. // TimeDiffStart, TimeDiffEnd,
  197. // tmpBaseIndex, infoBase.DialoguesEx[tmpBaseIndex].StartTime, infoBase.DialoguesEx[tmpBaseIndex].EndTime, infoBase.DialoguesEx[tmpBaseIndex].EnLine,
  198. // tmpSrcIndex, infoSrc.DialoguesEx[tmpSrcIndex].StartTime, infoSrc.DialoguesEx[tmpSrcIndex].EndTime, infoSrc.DialoguesEx[tmpSrcIndex].EnLine))
  199. }
  200. //println("---------------------------------------------")
  201. }
  202. oldMean := stat.Mean(startDiffTimeList)
  203. oldSd := stat.Sd(startDiffTimeList)
  204. newMean := -1.0
  205. newSd := -1.0
  206. per := 1.0
  207. // 如果 SD 较大的时候才需要剔除
  208. if oldSd > 0.1 {
  209. var outliersMap = make(map[float64]int, 0)
  210. outliers, _, _ := tukey.Outliers(0.1, tmpStartDiffTime)
  211. for _, outlier := range outliers {
  212. outliersMap[outlier] = 0
  213. }
  214. var newStartDiffTimeList = make([]float64, 0)
  215. for _, f := range tmpStartDiffTime {
  216. _, ok := outliersMap[f]
  217. if ok == true {
  218. continue
  219. }
  220. newStartDiffTimeList = append(newStartDiffTimeList, f)
  221. }
  222. orgLen := startDiffTimeList.Len()
  223. startDiffTimeList = make(stat.Float64Slice, 0)
  224. for _, f := range newStartDiffTimeList {
  225. startDiffTimeList = append(startDiffTimeList, f)
  226. }
  227. newLen := startDiffTimeList.Len()
  228. per = float64(newLen) / float64(orgLen)
  229. newMean = stat.Mean(startDiffTimeList)
  230. newSd = stat.Sd(startDiffTimeList)
  231. }
  232. if newMean == -1.0 {
  233. newMean = oldMean
  234. }
  235. if newSd == -1.0 {
  236. newSd = oldSd
  237. }
  238. err = SaveStaticLine("bar.html", infoBase.Name, infoSrc.Name,
  239. per, oldMean, oldSd, newMean, newSd, xAxis,
  240. startDiffTimeLineData, endDiffTimeLineData)
  241. if err != nil {
  242. return 0, err
  243. }
  244. return 0, nil
  245. }
  246. const timeFormatAss = "15:04:05.00"
  247. const timeFormatSrt = "15:04:05,000"