// gss.go

  1. package gss
  2. import (
  3. "log"
  4. "math"
  5. )
  6. var (
  7. sqrt5 = math.Sqrt(5)
  8. invphi = (sqrt5 - 1) / 2 //# 1/phi
  9. invphi2 = (3 - sqrt5) / 2 //# 1/phi^2
  10. nan = math.NaN()
  11. )
  12. // Gss golden section search (recursive version)
  13. // https://en.wikipedia.org/wiki/Golden-section_search
  14. // https://github.com/pa-m/optimize/blob/master/gss.go
  15. // '''
  16. // Golden section search, recursive.
  17. // Given a function f with a single local minimum in
  18. // the interval [a,b], gss returns a subset interval
  19. // [c,d] that contains the minimum with d-c <= tol.
  20. //
  21. // logger may be nil
  22. //
  23. // example:
  24. // >>> f = lambda x: (x-2)**2
  25. // >>> a = 1
  26. // >>> b = 5
  27. // >>> tol = 1e-5
  28. // >>> (c,d) = gssrec(f, a, b, tol)
  29. // >>> print (c,d)
  30. // (1.9999959837979107, 2.0000050911830893)
  31. // '''
  32. func Gss(f func(float64) float64, a, b, tol float64, logger *log.Logger) (float64, float64) {
  33. return gss(f, a, b, tol, nan, nan, nan, nan, nan, logger)
  34. }
  35. func gss(f func(float64) float64, a, b, tol, h, c, d, fc, fd float64, logger *log.Logger) (float64, float64) {
  36. if a > b {
  37. a, b = b, a
  38. }
  39. h = b - a
  40. it := 0
  41. for {
  42. if logger != nil {
  43. logger.Printf("%d\t%9.6g\t%9.6g\n", it, a, b)
  44. }
  45. it++
  46. if h < tol {
  47. return a, b
  48. }
  49. if a > b {
  50. a, b = b, a
  51. }
  52. if math.IsNaN(c) {
  53. c = a + invphi2*h
  54. fc = f(c)
  55. }
  56. if math.IsNaN(d) {
  57. d = a + invphi*h
  58. fd = f(d)
  59. }
  60. if fc < fd {
  61. b, h, c, fc, d, fd = d, h*invphi, nan, nan, c, fc
  62. } else {
  63. a, h, c, fc, d, fd = c, h*invphi, d, fd, nan, nan
  64. }
  65. }
  66. }