main.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build !plan9
  4. // The sync-containers command synchronizes container image tags from one
  5. // registry to another.
  6. //
  7. // It is intended as a workaround for ghcr.io's lack of good push credentials:
  8. // you can either authorize "classic" Personal Access Tokens in your org (which
  9. // are a common vector of very bad compromise), or you can get a short-lived
  10. // credential in a GitHub Action.
  11. //
  12. // Since we publish to both Docker Hub and ghcr.io, we use this program in a
  13. // GitHub Action to effectively rsync from Docker Hub into ghcr.io, so that we
  14. // can continue to forbid dangerous Personal Access Tokens in the tailscale org.
  15. package main
  16. import (
  17. "context"
  18. "flag"
  19. "fmt"
  20. "log"
  21. "sort"
  22. "strings"
  23. "github.com/google/go-containerregistry/pkg/authn"
  24. "github.com/google/go-containerregistry/pkg/authn/github"
  25. "github.com/google/go-containerregistry/pkg/name"
  26. v1 "github.com/google/go-containerregistry/pkg/v1"
  27. "github.com/google/go-containerregistry/pkg/v1/remote"
  28. "github.com/google/go-containerregistry/pkg/v1/types"
  29. )
  // Command-line flags, registered on flag.CommandLine at package
// initialization and read by main after flag.Parse.

// src is the source repository whose tags are copied.
var src = flag.String("src", "", "Source image")

// dst is the destination repository that receives the tags.
var dst = flag.String("dst", "", "Destination image")

// max caps how many new tags are copied per run; zero or negative means no cap.
// Note: this package-level name shadows the predeclared max builtin.
var max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)")

// dryRun defaults to true, so writes happen only when explicitly disabled.
var dryRun = flag.Bool("dry-run", true, "Don't actually sync anything")
  36. func main() {
  37. flag.Parse()
  38. if *src == "" {
  39. log.Fatalf("--src is required")
  40. }
  41. if *dst == "" {
  42. log.Fatalf("--dst is required")
  43. }
  44. keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain)
  45. opts := []remote.Option{
  46. remote.WithAuthFromKeychain(keychain),
  47. remote.WithContext(context.Background()),
  48. }
  49. stags, err := listTags(*src, opts...)
  50. if err != nil {
  51. log.Fatalf("listing source tags: %v", err)
  52. }
  53. dtags, err := listTags(*dst, opts...)
  54. if err != nil {
  55. log.Fatalf("listing destination tags: %v", err)
  56. }
  57. add, remove := diffTags(stags, dtags)
  58. if l := len(add); l > 0 {
  59. log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", "))
  60. if *max > 0 && l > *max {
  61. log.Printf("Limiting sync to %d tags", *max)
  62. add = add[:*max]
  63. }
  64. }
  65. for _, tag := range add {
  66. if !*dryRun {
  67. log.Printf("Syncing tag %q", tag)
  68. if err := copyTag(*src, *dst, tag, opts...); err != nil {
  69. log.Printf("Syncing tag %q: progress error: %v", tag, err)
  70. }
  71. } else {
  72. log.Printf("Dry run: would sync tag %q", tag)
  73. }
  74. }
  75. if len(remove) > 0 {
  76. log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", "))
  77. log.Printf("Not removing any tags for safety.\n")
  78. }
  79. var wellKnown = [...]string{"latest", "stable"}
  80. for _, tag := range wellKnown {
  81. if needsUpdate(*src, *dst, tag) {
  82. if err := copyTag(*src, *dst, tag, opts...); err != nil {
  83. log.Printf("Updating tag %q: progress error: %v", tag, err)
  84. }
  85. }
  86. }
  87. }
  88. func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error {
  89. src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag))
  90. if err != nil {
  91. return err
  92. }
  93. dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag))
  94. if err != nil {
  95. return err
  96. }
  97. desc, err := remote.Get(src)
  98. if err != nil {
  99. return err
  100. }
  101. ch := make(chan v1.Update, 10)
  102. opts = append(opts, remote.WithProgress(ch))
  103. progressDone := make(chan struct{})
  104. go func() {
  105. defer close(progressDone)
  106. for p := range ch {
  107. fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total)
  108. if p.Error != nil {
  109. fmt.Printf("error: %v\n", p.Error)
  110. }
  111. }
  112. }()
  113. switch desc.MediaType {
  114. case types.OCIManifestSchema1, types.DockerManifestSchema2:
  115. img, err := desc.Image()
  116. if err != nil {
  117. return err
  118. }
  119. if err := remote.Write(dst, img, opts...); err != nil {
  120. return err
  121. }
  122. case types.OCIImageIndex, types.DockerManifestList:
  123. idx, err := desc.ImageIndex()
  124. if err != nil {
  125. return err
  126. }
  127. if err := remote.WriteIndex(dst, idx, opts...); err != nil {
  128. return err
  129. }
  130. }
  131. <-progressDone
  132. return nil
  133. }
  134. func listTags(repoStr string, opts ...remote.Option) ([]string, error) {
  135. repo, err := name.NewRepository(repoStr)
  136. if err != nil {
  137. return nil, err
  138. }
  139. tags, err := remote.List(repo, opts...)
  140. if err != nil {
  141. return nil, err
  142. }
  143. sort.Strings(tags)
  144. return tags, nil
  145. }
  // diffTags compares two tag lists. add holds the entries of src that do not
// occur in dst, and remove holds the entries of dst that do not occur in src.
// Both results are sorted, keep any duplicate entries from their input, and
// are nil when empty.
func diffTags(src, dst []string) (add, remove []string) {
	// missing returns the elements of from that are absent from other, sorted.
	missing := func(from, other []string) []string {
		present := make(map[string]struct{}, len(other))
		for _, t := range other {
			present[t] = struct{}{}
		}
		var out []string
		for _, t := range from {
			if _, ok := present[t]; !ok {
				out = append(out, t)
			}
		}
		sort.Strings(out)
		return out
	}
	return missing(src, dst), missing(dst, src)
}
  168. }
  169. func needsUpdate(srcStr, dstStr, tag string) bool {
  170. src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag))
  171. if err != nil {
  172. return false
  173. }
  174. dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag))
  175. if err != nil {
  176. return false
  177. }
  178. srcDesc, err := remote.Get(src)
  179. if err != nil {
  180. return false
  181. }
  182. dstDesc, err := remote.Get(dst)
  183. if err != nil {
  184. return true
  185. }
  186. return srcDesc.Digest != dstDesc.Digest
  187. }