gss.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package gss
  2. import (
  3. "log"
  4. "math"
  5. )
  // Golden-ratio constants shared by the search routines in this file.
  var (
  	// sqrt5 is the square root of five, the building block of the ratios below.
  	sqrt5 = math.Sqrt(5)
  	// invphi is 1/phi (about 0.618): the factor by which the bracket shrinks per step.
  	invphi = 0.5 * (sqrt5 - 1)
  	// invphi2 is 1/phi^2 (about 0.382): the relative offset of the lower probe point.
  	invphi2 = 0.5 * (3 - sqrt5)
  	// nan is a ready-made NaN; gss uses it to mark probe points that still need evaluating.
  	nan = math.NaN()
  )
  12. // Gss golden section search (recursive version)
  13. // https://en.wikipedia.org/wiki/Golden-section_search
  14. // https://github.com/pa-m/optimize/blob/master/gss.go
  15. // '''
  16. // Golden section search, recursive.
  17. // Given a function f with a single local minimum in
  18. // the interval [a,b], gss returns a subset interval
  19. // [c,d] that contains the minimum with d-c <= tol.
  20. //
  21. // logger may be nil
  22. //
  23. // example:
  24. // >>> f = lambda x: (x-2)**2
  25. // >>> a = 1
  26. // >>> b = 5
  27. // >>> tol = 1e-5
  28. // >>> (c,d) = gssrec(f, a, b, tol)
  29. // >>> print (c,d)
  30. // (1.9999959837979107, 2.0000050911830893)
  31. // '''
  32. func Gss(fWrapped func(float64, bool) float64, a, b, tol float64, logger *log.Logger) (float64, float64) {
  33. if a > b {
  34. a, b = b, a
  35. }
  36. h := b - a
  37. if h <= tol {
  38. return a, b
  39. }
  40. n := int(math.Ceil(math.Log(tol/h) / math.Log(invphi)))
  41. if logger != nil {
  42. logger.Printf("About to perform %d iterations of golden section search to find the best framerate", n)
  43. }
  44. c := a + invphi2*h
  45. d := a + invphi*h
  46. yc := fWrapped(c, n == 1)
  47. yd := fWrapped(d, n == 1)
  48. for i := 0; i < n-1; i++ {
  49. if logger != nil {
  50. logger.Printf("%d\t%9.6g\t%9.6g\n", i, a, b)
  51. }
  52. if yc < yd {
  53. b = d
  54. d = c
  55. yd = yc
  56. h = invphi * h
  57. c = a + invphi2*h
  58. yc = fWrapped(c, i == n-2)
  59. } else {
  60. a = c
  61. c = d
  62. yc = yd
  63. h = invphi * h
  64. d = a + invphi*h
  65. yd = fWrapped(d, i == n-2)
  66. }
  67. }
  68. if yc < yd {
  69. return a, d
  70. } else {
  71. return c, b
  72. }
  73. //return gss(f, a, b, tol, nan, nan, nan, nan, nan, logger)
  74. }
  75. func gss(f func(float64) float64, a, b, tol, h, c, d, fc, fd float64, logger *log.Logger) (float64, float64) {
  76. if a > b {
  77. a, b = b, a
  78. }
  79. h = b - a
  80. it := 0
  81. for {
  82. if logger != nil {
  83. logger.Printf("%d\t%9.6g\t%9.6g\n", it, a, b)
  84. }
  85. it++
  86. if h < tol {
  87. return a, b
  88. }
  89. if a > b {
  90. a, b = b, a
  91. }
  92. if math.IsNaN(c) {
  93. c = a + invphi2*h
  94. fc = f(c)
  95. }
  96. if math.IsNaN(d) {
  97. d = a + invphi*h
  98. fd = f(d)
  99. }
  100. if fc < fd {
  101. b, h, c, fc, d, fd = d, h*invphi, nan, nan, c, fc
  102. } else {
  103. a, h, c, fc, d, fd = c, h*invphi, d, fd, nan, nan
  104. }
  105. }
  106. }