conn.go 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847
  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "io/ioutil"
  14. "net"
  15. "os"
  16. "os/user"
  17. "path"
  18. "path/filepath"
  19. "strconv"
  20. "strings"
  21. "time"
  22. "unicode"
  23. "github.com/lib/pq/oid"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
  32. )
  33. type drv struct{}
  34. func (d *drv) Open(name string) (driver.Conn, error) {
  35. return Open(name)
  36. }
  37. func init() {
  38. sql.Register("postgres", &drv{})
  39. }
  40. type parameterStatus struct {
  41. // server version in the same format as server_version_num, or 0 if
  42. // unavailable
  43. serverVersion int
  44. // the current location based on the TimeZone value of the session, if
  45. // available
  46. currentLocation *time.Location
  47. }
  48. type transactionStatus byte
  49. const (
  50. txnStatusIdle transactionStatus = 'I'
  51. txnStatusIdleInTransaction transactionStatus = 'T'
  52. txnStatusInFailedTransaction transactionStatus = 'E'
  53. )
  54. func (s transactionStatus) String() string {
  55. switch s {
  56. case txnStatusIdle:
  57. return "idle"
  58. case txnStatusIdleInTransaction:
  59. return "idle in transaction"
  60. case txnStatusInFailedTransaction:
  61. return "in a failed transaction"
  62. default:
  63. errorf("unknown transactionStatus %d", s)
  64. }
  65. panic("not reached")
  66. }
  67. type Dialer interface {
  68. Dial(network, address string) (net.Conn, error)
  69. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  70. }
  71. type defaultDialer struct{}
  72. func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
  73. return net.Dial(ntw, addr)
  74. }
  75. func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
  76. return net.DialTimeout(ntw, addr, timeout)
  77. }
  78. type conn struct {
  79. c net.Conn
  80. buf *bufio.Reader
  81. namei int
  82. scratch [512]byte
  83. txnStatus transactionStatus
  84. parameterStatus parameterStatus
  85. saveMessageType byte
  86. saveMessageBuffer []byte
  87. // If true, this connection is bad and all public-facing functions should
  88. // return ErrBadConn.
  89. bad bool
  90. // If set, this connection should never use the binary format when
  91. // receiving query results from prepared statements. Only provided for
  92. // debugging.
  93. disablePreparedBinaryResult bool
  94. // Whether to always send []byte parameters over as binary. Enables single
  95. // round-trip mode for non-prepared Query calls.
  96. binaryParameters bool
  97. }
  98. // Handle driver-side settings in parsed connection string.
  99. func (c *conn) handleDriverSettings(o values) (err error) {
  100. boolSetting := func(key string, val *bool) error {
  101. if value := o.Get(key); value != "" {
  102. if value == "yes" {
  103. *val = true
  104. } else if value == "no" {
  105. *val = false
  106. } else {
  107. return fmt.Errorf("unrecognized value %q for %s", value, key)
  108. }
  109. }
  110. return nil
  111. }
  112. err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult)
  113. if err != nil {
  114. return err
  115. }
  116. err = boolSetting("binary_parameters", &c.binaryParameters)
  117. if err != nil {
  118. return err
  119. }
  120. return nil
  121. }
  122. func (c *conn) handlePgpass(o values) {
  123. // if a password was supplied, do not process .pgpass
  124. _, ok := o["password"]
  125. if ok {
  126. return
  127. }
  128. filename := os.Getenv("PGPASSFILE")
  129. if filename == "" {
  130. // XXX this code doesn't work on Windows where the default filename is
  131. // XXX %APPDATA%\postgresql\pgpass.conf
  132. user, err := user.Current()
  133. if err != nil {
  134. return
  135. }
  136. filename = filepath.Join(user.HomeDir, ".pgpass")
  137. }
  138. fileinfo, err := os.Stat(filename)
  139. if err != nil {
  140. return
  141. }
  142. mode := fileinfo.Mode()
  143. if mode&(0x77) != 0 {
  144. // XXX should warn about incorrect .pgpass permissions as psql does
  145. return
  146. }
  147. file, err := os.Open(filename)
  148. if err != nil {
  149. return
  150. }
  151. defer file.Close()
  152. scanner := bufio.NewScanner(io.Reader(file))
  153. hostname := o.Get("host")
  154. ntw, _ := network(o)
  155. port := o.Get("port")
  156. db := o.Get("dbname")
  157. username := o.Get("user")
  158. // From: https://github.com/tg/pgpass/blob/master/reader.go
  159. getFields := func(s string) []string {
  160. fs := make([]string, 0, 5)
  161. f := make([]rune, 0, len(s))
  162. var esc bool
  163. for _, c := range s {
  164. switch {
  165. case esc:
  166. f = append(f, c)
  167. esc = false
  168. case c == '\\':
  169. esc = true
  170. case c == ':':
  171. fs = append(fs, string(f))
  172. f = f[:0]
  173. default:
  174. f = append(f, c)
  175. }
  176. }
  177. return append(fs, string(f))
  178. }
  179. for scanner.Scan() {
  180. line := scanner.Text()
  181. if len(line) == 0 || line[0] == '#' {
  182. continue
  183. }
  184. split := getFields(line)
  185. if len(split) != 5 {
  186. continue
  187. }
  188. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  189. o["password"] = split[4]
  190. return
  191. }
  192. }
  193. }
  194. func (c *conn) writeBuf(b byte) *writeBuf {
  195. c.scratch[0] = b
  196. return &writeBuf{
  197. buf: c.scratch[:5],
  198. pos: 1,
  199. }
  200. }
  201. func Open(name string) (_ driver.Conn, err error) {
  202. return DialOpen(defaultDialer{}, name)
  203. }
  204. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  205. // Handle any panics during connection initialization. Note that we
  206. // specifically do *not* want to use errRecover(), as that would turn any
  207. // connection errors into ErrBadConns, hiding the real error message from
  208. // the user.
  209. defer errRecoverNoErrBadConn(&err)
  210. o := make(values)
  211. // A number of defaults are applied here, in this order:
  212. //
  213. // * Very low precedence defaults applied in every situation
  214. // * Environment variables
  215. // * Explicitly passed connection information
  216. o.Set("host", "localhost")
  217. o.Set("port", "5432")
  218. // N.B.: Extra float digits should be set to 3, but that breaks
  219. // Postgres 8.4 and older, where the max is 2.
  220. o.Set("extra_float_digits", "2")
  221. for k, v := range parseEnviron(os.Environ()) {
  222. o.Set(k, v)
  223. }
  224. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  225. name, err = ParseURL(name)
  226. if err != nil {
  227. return nil, err
  228. }
  229. }
  230. if err := parseOpts(name, o); err != nil {
  231. return nil, err
  232. }
  233. // Use the "fallback" application name if necessary
  234. if fallback := o.Get("fallback_application_name"); fallback != "" {
  235. if !o.Isset("application_name") {
  236. o.Set("application_name", fallback)
  237. }
  238. }
  239. // We can't work with any client_encoding other than UTF-8 currently.
  240. // However, we have historically allowed the user to set it to UTF-8
  241. // explicitly, and there's no reason to break such programs, so allow that.
  242. // Note that the "options" setting could also set client_encoding, but
  243. // parsing its value is not worth it. Instead, we always explicitly send
  244. // client_encoding as a separate run-time parameter, which should override
  245. // anything set in options.
  246. if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) {
  247. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  248. }
  249. o.Set("client_encoding", "UTF8")
  250. // DateStyle needs a similar treatment.
  251. if datestyle := o.Get("datestyle"); datestyle != "" {
  252. if datestyle != "ISO, MDY" {
  253. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  254. "ISO, MDY", datestyle))
  255. }
  256. } else {
  257. o.Set("datestyle", "ISO, MDY")
  258. }
  259. // If a user is not provided by any other means, the last
  260. // resort is to use the current operating system provided user
  261. // name.
  262. if o.Get("user") == "" {
  263. u, err := userCurrent()
  264. if err != nil {
  265. return nil, err
  266. } else {
  267. o.Set("user", u)
  268. }
  269. }
  270. cn := &conn{}
  271. err = cn.handleDriverSettings(o)
  272. if err != nil {
  273. return nil, err
  274. }
  275. cn.handlePgpass(o)
  276. cn.c, err = dial(d, o)
  277. if err != nil {
  278. return nil, err
  279. }
  280. cn.ssl(o)
  281. cn.buf = bufio.NewReader(cn.c)
  282. cn.startup(o)
  283. // reset the deadline, in case one was set (see dial)
  284. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  285. err = cn.c.SetDeadline(time.Time{})
  286. }
  287. return cn, err
  288. }
  289. func dial(d Dialer, o values) (net.Conn, error) {
  290. ntw, addr := network(o)
  291. // SSL is not necessary or supported over UNIX domain sockets
  292. if ntw == "unix" {
  293. o["sslmode"] = "disable"
  294. }
  295. // Zero or not specified means wait indefinitely.
  296. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  297. seconds, err := strconv.ParseInt(timeout, 10, 0)
  298. if err != nil {
  299. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  300. }
  301. duration := time.Duration(seconds) * time.Second
  302. // connect_timeout should apply to the entire connection establishment
  303. // procedure, so we both use a timeout for the TCP connection
  304. // establishment and set a deadline for doing the initial handshake.
  305. // The deadline is then reset after startup() is done.
  306. deadline := time.Now().Add(duration)
  307. conn, err := d.DialTimeout(ntw, addr, duration)
  308. if err != nil {
  309. return nil, err
  310. }
  311. err = conn.SetDeadline(deadline)
  312. return conn, err
  313. }
  314. return d.Dial(ntw, addr)
  315. }
  316. func network(o values) (string, string) {
  317. host := o.Get("host")
  318. if strings.HasPrefix(host, "/") {
  319. sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
  320. return "unix", sockPath
  321. }
  322. return "tcp", net.JoinHostPort(host, o.Get("port"))
  323. }
  324. type values map[string]string
  325. func (vs values) Set(k, v string) {
  326. vs[k] = v
  327. }
  328. func (vs values) Get(k string) (v string) {
  329. return vs[k]
  330. }
  331. func (vs values) Isset(k string) bool {
  332. _, ok := vs[k]
  333. return ok
  334. }
  335. // scanner implements a tokenizer for libpq-style option strings.
  336. type scanner struct {
  337. s []rune
  338. i int
  339. }
  340. // newScanner returns a new scanner initialized with the option string s.
  341. func newScanner(s string) *scanner {
  342. return &scanner{[]rune(s), 0}
  343. }
  344. // Next returns the next rune.
  345. // It returns 0, false if the end of the text has been reached.
  346. func (s *scanner) Next() (rune, bool) {
  347. if s.i >= len(s.s) {
  348. return 0, false
  349. }
  350. r := s.s[s.i]
  351. s.i++
  352. return r, true
  353. }
  354. // SkipSpaces returns the next non-whitespace rune.
  355. // It returns 0, false if the end of the text has been reached.
  356. func (s *scanner) SkipSpaces() (rune, bool) {
  357. r, ok := s.Next()
  358. for unicode.IsSpace(r) && ok {
  359. r, ok = s.Next()
  360. }
  361. return r, ok
  362. }
  363. // parseOpts parses the options from name and adds them to the values.
  364. //
  365. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  366. func parseOpts(name string, o values) error {
  367. s := newScanner(name)
  368. for {
  369. var (
  370. keyRunes, valRunes []rune
  371. r rune
  372. ok bool
  373. )
  374. if r, ok = s.SkipSpaces(); !ok {
  375. break
  376. }
  377. // Scan the key
  378. for !unicode.IsSpace(r) && r != '=' {
  379. keyRunes = append(keyRunes, r)
  380. if r, ok = s.Next(); !ok {
  381. break
  382. }
  383. }
  384. // Skip any whitespace if we're not at the = yet
  385. if r != '=' {
  386. r, ok = s.SkipSpaces()
  387. }
  388. // The current character should be =
  389. if r != '=' || !ok {
  390. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  391. }
  392. // Skip any whitespace after the =
  393. if r, ok = s.SkipSpaces(); !ok {
  394. // If we reach the end here, the last value is just an empty string as per libpq.
  395. o.Set(string(keyRunes), "")
  396. break
  397. }
  398. if r != '\'' {
  399. for !unicode.IsSpace(r) {
  400. if r == '\\' {
  401. if r, ok = s.Next(); !ok {
  402. return fmt.Errorf(`missing character after backslash`)
  403. }
  404. }
  405. valRunes = append(valRunes, r)
  406. if r, ok = s.Next(); !ok {
  407. break
  408. }
  409. }
  410. } else {
  411. quote:
  412. for {
  413. if r, ok = s.Next(); !ok {
  414. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  415. }
  416. switch r {
  417. case '\'':
  418. break quote
  419. case '\\':
  420. r, _ = s.Next()
  421. fallthrough
  422. default:
  423. valRunes = append(valRunes, r)
  424. }
  425. }
  426. }
  427. o.Set(string(keyRunes), string(valRunes))
  428. }
  429. return nil
  430. }
  431. func (cn *conn) isInTransaction() bool {
  432. return cn.txnStatus == txnStatusIdleInTransaction ||
  433. cn.txnStatus == txnStatusInFailedTransaction
  434. }
  435. func (cn *conn) checkIsInTransaction(intxn bool) {
  436. if cn.isInTransaction() != intxn {
  437. cn.bad = true
  438. errorf("unexpected transaction status %v", cn.txnStatus)
  439. }
  440. }
  441. func (cn *conn) Begin() (_ driver.Tx, err error) {
  442. if cn.bad {
  443. return nil, driver.ErrBadConn
  444. }
  445. defer cn.errRecover(&err)
  446. cn.checkIsInTransaction(false)
  447. _, commandTag, err := cn.simpleExec("BEGIN")
  448. if err != nil {
  449. return nil, err
  450. }
  451. if commandTag != "BEGIN" {
  452. cn.bad = true
  453. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  454. }
  455. if cn.txnStatus != txnStatusIdleInTransaction {
  456. cn.bad = true
  457. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  458. }
  459. return cn, nil
  460. }
  461. func (cn *conn) Commit() (err error) {
  462. if cn.bad {
  463. return driver.ErrBadConn
  464. }
  465. defer cn.errRecover(&err)
  466. cn.checkIsInTransaction(true)
  467. // We don't want the client to think that everything is okay if it tries
  468. // to commit a failed transaction. However, no matter what we return,
  469. // database/sql will release this connection back into the free connection
  470. // pool so we have to abort the current transaction here. Note that you
  471. // would get the same behaviour if you issued a COMMIT in a failed
  472. // transaction, so it's also the least surprising thing to do here.
  473. if cn.txnStatus == txnStatusInFailedTransaction {
  474. if err := cn.Rollback(); err != nil {
  475. return err
  476. }
  477. return ErrInFailedTransaction
  478. }
  479. _, commandTag, err := cn.simpleExec("COMMIT")
  480. if err != nil {
  481. if cn.isInTransaction() {
  482. cn.bad = true
  483. }
  484. return err
  485. }
  486. if commandTag != "COMMIT" {
  487. cn.bad = true
  488. return fmt.Errorf("unexpected command tag %s", commandTag)
  489. }
  490. cn.checkIsInTransaction(false)
  491. return nil
  492. }
  493. func (cn *conn) Rollback() (err error) {
  494. if cn.bad {
  495. return driver.ErrBadConn
  496. }
  497. defer cn.errRecover(&err)
  498. cn.checkIsInTransaction(true)
  499. _, commandTag, err := cn.simpleExec("ROLLBACK")
  500. if err != nil {
  501. if cn.isInTransaction() {
  502. cn.bad = true
  503. }
  504. return err
  505. }
  506. if commandTag != "ROLLBACK" {
  507. return fmt.Errorf("unexpected command tag %s", commandTag)
  508. }
  509. cn.checkIsInTransaction(false)
  510. return nil
  511. }
  512. func (cn *conn) gname() string {
  513. cn.namei++
  514. return strconv.FormatInt(int64(cn.namei), 10)
  515. }
  516. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  517. b := cn.writeBuf('Q')
  518. b.string(q)
  519. cn.send(b)
  520. for {
  521. t, r := cn.recv1()
  522. switch t {
  523. case 'C':
  524. res, commandTag = cn.parseComplete(r.string())
  525. case 'Z':
  526. cn.processReadyForQuery(r)
  527. // done
  528. return
  529. case 'E':
  530. err = parseError(r)
  531. case 'T', 'D', 'I':
  532. // ignore any results
  533. default:
  534. cn.bad = true
  535. errorf("unknown response for simple query: %q", t)
  536. }
  537. }
  538. }
  539. func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  540. defer cn.errRecover(&err)
  541. b := cn.writeBuf('Q')
  542. b.string(q)
  543. cn.send(b)
  544. for {
  545. t, r := cn.recv1()
  546. switch t {
  547. case 'C', 'I':
  548. // We allow queries which don't return any results through Query as
  549. // well as Exec. We still have to give database/sql a rows object
  550. // the user can close, though, to avoid connections from being
  551. // leaked. A "rows" with done=true works fine for that purpose.
  552. if err != nil {
  553. cn.bad = true
  554. errorf("unexpected message %q in simple query execution", t)
  555. }
  556. if res == nil {
  557. res = &rows{
  558. cn: cn,
  559. }
  560. }
  561. res.done = true
  562. case 'Z':
  563. cn.processReadyForQuery(r)
  564. // done
  565. return
  566. case 'E':
  567. res = nil
  568. err = parseError(r)
  569. case 'D':
  570. if res == nil {
  571. cn.bad = true
  572. errorf("unexpected DataRow in simple query execution")
  573. }
  574. // the query didn't fail; kick off to Next
  575. cn.saveMessage(t, r)
  576. return
  577. case 'T':
  578. // res might be non-nil here if we received a previous
  579. // CommandComplete, but that's fine; just overwrite it
  580. res = &rows{cn: cn}
  581. res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
  582. // To work around a bug in QueryRow in Go 1.2 and earlier, wait
  583. // until the first DataRow has been received.
  584. default:
  585. cn.bad = true
  586. errorf("unknown response for simple query: %q", t)
  587. }
  588. }
  589. }
  590. // Decides which column formats to use for a prepared statement. The input is
  591. // an array of type oids, one element per result column.
  592. func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) {
  593. if len(colTyps) == 0 {
  594. return nil, colFmtDataAllText
  595. }
  596. colFmts = make([]format, len(colTyps))
  597. if forceText {
  598. return colFmts, colFmtDataAllText
  599. }
  600. allBinary := true
  601. allText := true
  602. for i, o := range colTyps {
  603. switch o {
  604. // This is the list of types to use binary mode for when receiving them
  605. // through a prepared statement. If a type appears in this list, it
  606. // must also be implemented in binaryDecode in encode.go.
  607. case oid.T_bytea:
  608. fallthrough
  609. case oid.T_int8:
  610. fallthrough
  611. case oid.T_int4:
  612. fallthrough
  613. case oid.T_int2:
  614. colFmts[i] = formatBinary
  615. allText = false
  616. default:
  617. allBinary = false
  618. }
  619. }
  620. if allBinary {
  621. return colFmts, colFmtDataAllBinary
  622. } else if allText {
  623. return colFmts, colFmtDataAllText
  624. } else {
  625. colFmtData = make([]byte, 2+len(colFmts)*2)
  626. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  627. for i, v := range colFmts {
  628. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  629. }
  630. return colFmts, colFmtData
  631. }
  632. }
  633. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  634. st := &stmt{cn: cn, name: stmtName}
  635. b := cn.writeBuf('P')
  636. b.string(st.name)
  637. b.string(q)
  638. b.int16(0)
  639. b.next('D')
  640. b.byte('S')
  641. b.string(st.name)
  642. b.next('S')
  643. cn.send(b)
  644. cn.readParseResponse()
  645. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  646. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  647. cn.readReadyForQuery()
  648. return st
  649. }
  650. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  651. if cn.bad {
  652. return nil, driver.ErrBadConn
  653. }
  654. defer cn.errRecover(&err)
  655. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  656. return cn.prepareCopyIn(q)
  657. }
  658. return cn.prepareTo(q, cn.gname()), nil
  659. }
  660. func (cn *conn) Close() (err error) {
  661. if cn.bad {
  662. return driver.ErrBadConn
  663. }
  664. defer cn.errRecover(&err)
  665. // Don't go through send(); ListenerConn relies on us not scribbling on the
  666. // scratch buffer of this connection.
  667. err = cn.sendSimpleMessage('X')
  668. if err != nil {
  669. return err
  670. }
  671. return cn.c.Close()
  672. }
  673. // Implement the "Queryer" interface
  674. func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
  675. if cn.bad {
  676. return nil, driver.ErrBadConn
  677. }
  678. defer cn.errRecover(&err)
  679. // Check to see if we can use the "simpleQuery" interface, which is
  680. // *much* faster than going through prepare/exec
  681. if len(args) == 0 {
  682. return cn.simpleQuery(query)
  683. }
  684. if cn.binaryParameters {
  685. cn.sendBinaryModeQuery(query, args)
  686. cn.readParseResponse()
  687. cn.readBindResponse()
  688. rows := &rows{cn: cn}
  689. rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
  690. cn.postExecuteWorkaround()
  691. return rows, nil
  692. } else {
  693. st := cn.prepareTo(query, "")
  694. st.exec(args)
  695. return &rows{
  696. cn: cn,
  697. colNames: st.colNames,
  698. colTyps: st.colTyps,
  699. colFmts: st.colFmts,
  700. }, nil
  701. }
  702. }
  703. // Implement the optional "Execer" interface for one-shot queries
  704. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  705. if cn.bad {
  706. return nil, driver.ErrBadConn
  707. }
  708. defer cn.errRecover(&err)
  709. // Check to see if we can use the "simpleExec" interface, which is
  710. // *much* faster than going through prepare/exec
  711. if len(args) == 0 {
  712. // ignore commandTag, our caller doesn't care
  713. r, _, err := cn.simpleExec(query)
  714. return r, err
  715. }
  716. if cn.binaryParameters {
  717. cn.sendBinaryModeQuery(query, args)
  718. cn.readParseResponse()
  719. cn.readBindResponse()
  720. cn.readPortalDescribeResponse()
  721. cn.postExecuteWorkaround()
  722. res, _, err = cn.readExecuteResponse("Execute")
  723. return res, err
  724. } else {
  725. // Use the unnamed statement to defer planning until bind
  726. // time, or else value-based selectivity estimates cannot be
  727. // used.
  728. st := cn.prepareTo(query, "")
  729. r, err := st.Exec(args)
  730. if err != nil {
  731. panic(err)
  732. }
  733. return r, err
  734. }
  735. }
  736. func (cn *conn) send(m *writeBuf) {
  737. _, err := cn.c.Write(m.wrap())
  738. if err != nil {
  739. panic(err)
  740. }
  741. }
  742. func (cn *conn) sendStartupPacket(m *writeBuf) {
  743. // sanity check
  744. if m.buf[0] != 0 {
  745. panic("oops")
  746. }
  747. _, err := cn.c.Write((m.wrap())[1:])
  748. if err != nil {
  749. panic(err)
  750. }
  751. }
  752. // Send a message of type typ to the server on the other end of cn. The
  753. // message should have no payload. This method does not use the scratch
  754. // buffer.
  755. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  756. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  757. return err
  758. }
  759. // saveMessage memorizes a message and its buffer in the conn struct.
  760. // recvMessage will then return these values on the next call to it. This
  761. // method is useful in cases where you have to see what the next message is
  762. // going to be (e.g. to see whether it's an error or not) but you can't handle
  763. // the message yourself.
  764. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  765. if cn.saveMessageType != 0 {
  766. cn.bad = true
  767. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  768. }
  769. cn.saveMessageType = typ
  770. cn.saveMessageBuffer = *buf
  771. }
  772. // recvMessage receives any message from the backend, or returns an error if
  773. // a problem occurred while reading the message.
  774. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  775. // workaround for a QueryRow bug, see exec
  776. if cn.saveMessageType != 0 {
  777. t := cn.saveMessageType
  778. *r = cn.saveMessageBuffer
  779. cn.saveMessageType = 0
  780. cn.saveMessageBuffer = nil
  781. return t, nil
  782. }
  783. x := cn.scratch[:5]
  784. _, err := io.ReadFull(cn.buf, x)
  785. if err != nil {
  786. return 0, err
  787. }
  788. // read the type and length of the message that follows
  789. t := x[0]
  790. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  791. var y []byte
  792. if n <= len(cn.scratch) {
  793. y = cn.scratch[:n]
  794. } else {
  795. y = make([]byte, n)
  796. }
  797. _, err = io.ReadFull(cn.buf, y)
  798. if err != nil {
  799. return 0, err
  800. }
  801. *r = y
  802. return t, nil
  803. }
  804. // recv receives a message from the backend, but if an error happened while
  805. // reading the message or the received message was an ErrorResponse, it panics.
  806. // NoticeResponses are ignored. This function should generally be used only
  807. // during the startup sequence.
  808. func (cn *conn) recv() (t byte, r *readBuf) {
  809. for {
  810. var err error
  811. r = &readBuf{}
  812. t, err = cn.recvMessage(r)
  813. if err != nil {
  814. panic(err)
  815. }
  816. switch t {
  817. case 'E':
  818. panic(parseError(r))
  819. case 'N':
  820. // ignore
  821. default:
  822. return
  823. }
  824. }
  825. }
  826. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  827. // the caller to avoid an allocation.
  828. func (cn *conn) recv1Buf(r *readBuf) byte {
  829. for {
  830. t, err := cn.recvMessage(r)
  831. if err != nil {
  832. panic(err)
  833. }
  834. switch t {
  835. case 'A', 'N':
  836. // ignore
  837. case 'S':
  838. cn.processParameterStatus(r)
  839. default:
  840. return t
  841. }
  842. }
  843. }
  844. // recv1 receives a message from the backend, panicking if an error occurs
  845. // while attempting to read it. All asynchronous messages are ignored, with
  846. // the exception of ErrorResponse.
  847. func (cn *conn) recv1() (t byte, r *readBuf) {
  848. r = &readBuf{}
  849. t = cn.recv1Buf(r)
  850. return t, r
  851. }
  852. func (cn *conn) ssl(o values) {
  853. verifyCaOnly := false
  854. tlsConf := tls.Config{}
  855. switch mode := o.Get("sslmode"); mode {
  856. case "require", "":
  857. tlsConf.InsecureSkipVerify = true
  858. case "verify-ca":
  859. // We must skip TLS's own verification since it requires full
  860. // verification since Go 1.3.
  861. tlsConf.InsecureSkipVerify = true
  862. verifyCaOnly = true
  863. case "verify-full":
  864. tlsConf.ServerName = o.Get("host")
  865. case "disable":
  866. return
  867. default:
  868. errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
  869. }
  870. cn.setupSSLClientCertificates(&tlsConf, o)
  871. cn.setupSSLCA(&tlsConf, o)
  872. w := cn.writeBuf(0)
  873. w.int32(80877103)
  874. cn.sendStartupPacket(w)
  875. b := cn.scratch[:1]
  876. _, err := io.ReadFull(cn.c, b)
  877. if err != nil {
  878. panic(err)
  879. }
  880. if b[0] != 'S' {
  881. panic(ErrSSLNotSupported)
  882. }
  883. client := tls.Client(cn.c, &tlsConf)
  884. if verifyCaOnly {
  885. cn.verifyCA(client, &tlsConf)
  886. }
  887. cn.c = client
  888. }
  889. // verifyCA carries out a TLS handshake to the server and verifies the
  890. // presented certificate against the effective CA, i.e. the one specified in
  891. // sslrootcert or the system CA if sslrootcert was not specified.
  892. func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) {
  893. err := client.Handshake()
  894. if err != nil {
  895. panic(err)
  896. }
  897. certs := client.ConnectionState().PeerCertificates
  898. opts := x509.VerifyOptions{
  899. DNSName: client.ConnectionState().ServerName,
  900. Intermediates: x509.NewCertPool(),
  901. Roots: tlsConf.RootCAs,
  902. }
  903. for i, cert := range certs {
  904. if i == 0 {
  905. continue
  906. }
  907. opts.Intermediates.AddCert(cert)
  908. }
  909. _, err = certs[0].Verify(opts)
  910. if err != nil {
  911. panic(err)
  912. }
  913. }
  914. // This function sets up SSL client certificates based on either the "sslkey"
  915. // and "sslcert" settings (possibly set via the environment variables PGSSLKEY
  916. // and PGSSLCERT, respectively), or if they aren't set, from the .postgresql
  917. // directory in the user's home directory. If the file paths are set
  918. // explicitly, the files must exist. The key file must also not be
  919. // world-readable, or this function will panic with
  920. // ErrSSLKeyHasWorldPermissions.
  921. func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) {
  922. var missingOk bool
  923. sslkey := o.Get("sslkey")
  924. sslcert := o.Get("sslcert")
  925. if sslkey != "" && sslcert != "" {
  926. // If the user has set an sslkey and sslcert, they *must* exist.
  927. missingOk = false
  928. } else {
  929. // Automatically load certificates from ~/.postgresql.
  930. user, err := user.Current()
  931. if err != nil {
  932. // user.Current() might fail when cross-compiling. We have to
  933. // ignore the error and continue without client certificates, since
  934. // we wouldn't know where to load them from.
  935. return
  936. }
  937. sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
  938. sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
  939. missingOk = true
  940. }
  941. // Check that both files exist, and report the error or stop, depending on
  942. // which behaviour we want. Note that we don't do any more extensive
  943. // checks than this (such as checking that the paths aren't directories);
  944. // LoadX509KeyPair() will take care of the rest.
  945. keyfinfo, err := os.Stat(sslkey)
  946. if err != nil && missingOk {
  947. return
  948. } else if err != nil {
  949. panic(err)
  950. }
  951. _, err = os.Stat(sslcert)
  952. if err != nil && missingOk {
  953. return
  954. } else if err != nil {
  955. panic(err)
  956. }
  957. // If we got this far, the key file must also have the correct permissions
  958. kmode := keyfinfo.Mode()
  959. if kmode != kmode&0600 {
  960. panic(ErrSSLKeyHasWorldPermissions)
  961. }
  962. cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
  963. if err != nil {
  964. panic(err)
  965. }
  966. tlsConf.Certificates = []tls.Certificate{cert}
  967. }
  968. // Sets up RootCAs in the TLS configuration if sslrootcert is set.
  969. func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) {
  970. if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" {
  971. tlsConf.RootCAs = x509.NewCertPool()
  972. cert, err := ioutil.ReadFile(sslrootcert)
  973. if err != nil {
  974. panic(err)
  975. }
  976. ok := tlsConf.RootCAs.AppendCertsFromPEM(cert)
  977. if !ok {
  978. errorf("couldn't parse pem in sslrootcert")
  979. }
  980. }
  981. }
  982. // isDriverSetting returns true iff a setting is purely for configuring the
  983. // driver's options and should not be sent to the server in the connection
  984. // startup packet.
  985. func isDriverSetting(key string) bool {
  986. switch key {
  987. case "host", "port":
  988. return true
  989. case "password":
  990. return true
  991. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  992. return true
  993. case "fallback_application_name":
  994. return true
  995. case "connect_timeout":
  996. return true
  997. case "disable_prepared_binary_result":
  998. return true
  999. case "binary_parameters":
  1000. return true
  1001. default:
  1002. return false
  1003. }
  1004. }
  1005. func (cn *conn) startup(o values) {
  1006. w := cn.writeBuf(0)
  1007. w.int32(196608)
  1008. // Send the backend the name of the database we want to connect to, and the
  1009. // user we want to connect as. Additionally, we send over any run-time
  1010. // parameters potentially included in the connection string. If the server
  1011. // doesn't recognize any of them, it will reply with an error.
  1012. for k, v := range o {
  1013. if isDriverSetting(k) {
  1014. // skip options which can't be run-time parameters
  1015. continue
  1016. }
  1017. // The protocol requires us to supply the database name as "database"
  1018. // instead of "dbname".
  1019. if k == "dbname" {
  1020. k = "database"
  1021. }
  1022. w.string(k)
  1023. w.string(v)
  1024. }
  1025. w.string("")
  1026. cn.sendStartupPacket(w)
  1027. for {
  1028. t, r := cn.recv()
  1029. switch t {
  1030. case 'K':
  1031. case 'S':
  1032. cn.processParameterStatus(r)
  1033. case 'R':
  1034. cn.auth(r, o)
  1035. case 'Z':
  1036. cn.processReadyForQuery(r)
  1037. return
  1038. default:
  1039. errorf("unknown response for startup: %q", t)
  1040. }
  1041. }
  1042. }
  1043. func (cn *conn) auth(r *readBuf, o values) {
  1044. switch code := r.int32(); code {
  1045. case 0:
  1046. // OK
  1047. case 3:
  1048. w := cn.writeBuf('p')
  1049. w.string(o.Get("password"))
  1050. cn.send(w)
  1051. t, r := cn.recv()
  1052. if t != 'R' {
  1053. errorf("unexpected password response: %q", t)
  1054. }
  1055. if r.int32() != 0 {
  1056. errorf("unexpected authentication response: %q", t)
  1057. }
  1058. case 5:
  1059. s := string(r.next(4))
  1060. w := cn.writeBuf('p')
  1061. w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
  1062. cn.send(w)
  1063. t, r := cn.recv()
  1064. if t != 'R' {
  1065. errorf("unexpected password response: %q", t)
  1066. }
  1067. if r.int32() != 0 {
  1068. errorf("unexpected authentication response: %q", t)
  1069. }
  1070. default:
  1071. errorf("unknown authentication response: %d", code)
  1072. }
  1073. }
  1074. type format int
  1075. const formatText format = 0
  1076. const formatBinary format = 1
  1077. // One result-column format code with the value 1 (i.e. all binary).
  1078. var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1}
  1079. // No result-column format codes (i.e. all text).
  1080. var colFmtDataAllText []byte = []byte{0, 0}
// stmt is a statement prepared on a connection; it is addressed by name in
// the Bind and Close messages it sends.
type stmt struct {
	// cn is the connection the statement belongs to.
	cn *conn
	// name is the statement name written into Bind and Close messages.
	name string
	// colNames, colFmts and colTyps describe the result columns; Query
	// copies them into each rows value it returns.
	colNames []string
	colFmts  []format
	// colFmtData is appended verbatim to each Bind message as its
	// result-column format section.
	colFmtData []byte
	colTyps    []oid.Oid
	// paramTyps holds one type per parameter; its length is the number of
	// arguments exec requires (and what NumInput reports).
	paramTyps []oid.Oid
	// closed is set once the server acknowledges Close; later Close calls
	// then return nil immediately.
	closed bool
}
  1091. func (st *stmt) Close() (err error) {
  1092. if st.closed {
  1093. return nil
  1094. }
  1095. if st.cn.bad {
  1096. return driver.ErrBadConn
  1097. }
  1098. defer st.cn.errRecover(&err)
  1099. w := st.cn.writeBuf('C')
  1100. w.byte('S')
  1101. w.string(st.name)
  1102. st.cn.send(w)
  1103. st.cn.send(st.cn.writeBuf('S'))
  1104. t, _ := st.cn.recv1()
  1105. if t != '3' {
  1106. st.cn.bad = true
  1107. errorf("unexpected close response: %q", t)
  1108. }
  1109. st.closed = true
  1110. t, r := st.cn.recv1()
  1111. if t != 'Z' {
  1112. st.cn.bad = true
  1113. errorf("expected ready for query, but got: %q", t)
  1114. }
  1115. st.cn.processReadyForQuery(r)
  1116. return nil
  1117. }
  1118. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1119. if st.cn.bad {
  1120. return nil, driver.ErrBadConn
  1121. }
  1122. defer st.cn.errRecover(&err)
  1123. st.exec(v)
  1124. return &rows{
  1125. cn: st.cn,
  1126. colNames: st.colNames,
  1127. colTyps: st.colTyps,
  1128. colFmts: st.colFmts,
  1129. }, nil
  1130. }
  1131. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1132. if st.cn.bad {
  1133. return nil, driver.ErrBadConn
  1134. }
  1135. defer st.cn.errRecover(&err)
  1136. st.exec(v)
  1137. res, _, err = st.cn.readExecuteResponse("simple query")
  1138. return res, err
  1139. }
  1140. func (st *stmt) exec(v []driver.Value) {
  1141. if len(v) >= 65536 {
  1142. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1143. }
  1144. if len(v) != len(st.paramTyps) {
  1145. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1146. }
  1147. cn := st.cn
  1148. w := cn.writeBuf('B')
  1149. w.byte(0) // unnamed portal
  1150. w.string(st.name)
  1151. if cn.binaryParameters {
  1152. cn.sendBinaryParameters(w, v)
  1153. } else {
  1154. w.int16(0)
  1155. w.int16(len(v))
  1156. for i, x := range v {
  1157. if x == nil {
  1158. w.int32(-1)
  1159. } else {
  1160. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1161. w.int32(len(b))
  1162. w.bytes(b)
  1163. }
  1164. }
  1165. }
  1166. w.bytes(st.colFmtData)
  1167. w.next('E')
  1168. w.byte(0)
  1169. w.int32(0)
  1170. w.next('S')
  1171. cn.send(w)
  1172. cn.readBindResponse()
  1173. cn.postExecuteWorkaround()
  1174. }
  1175. func (st *stmt) NumInput() int {
  1176. return len(st.paramTyps)
  1177. }
  1178. // parseComplete parses the "command tag" from a CommandComplete message, and
  1179. // returns the number of rows affected (if applicable) and a string
  1180. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1181. // command tag could not be parsed, parseComplete panics.
  1182. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1183. commandsWithAffectedRows := []string{
  1184. "SELECT ",
  1185. // INSERT is handled below
  1186. "UPDATE ",
  1187. "DELETE ",
  1188. "FETCH ",
  1189. "MOVE ",
  1190. "COPY ",
  1191. }
  1192. var affectedRows *string
  1193. for _, tag := range commandsWithAffectedRows {
  1194. if strings.HasPrefix(commandTag, tag) {
  1195. t := commandTag[len(tag):]
  1196. affectedRows = &t
  1197. commandTag = tag[:len(tag)-1]
  1198. break
  1199. }
  1200. }
  1201. // INSERT also includes the oid of the inserted row in its command tag.
  1202. // Oids in user tables are deprecated, and the oid is only returned when
  1203. // exactly one row is inserted, so it's unlikely to be of value to any
  1204. // real-world application and we can ignore it.
  1205. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1206. parts := strings.Split(commandTag, " ")
  1207. if len(parts) != 3 {
  1208. cn.bad = true
  1209. errorf("unexpected INSERT command tag %s", commandTag)
  1210. }
  1211. affectedRows = &parts[len(parts)-1]
  1212. commandTag = "INSERT"
  1213. }
  1214. // There should be no affected rows attached to the tag, just return it
  1215. if affectedRows == nil {
  1216. return driver.RowsAffected(0), commandTag
  1217. }
  1218. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1219. if err != nil {
  1220. cn.bad = true
  1221. errorf("could not parse commandTag: %s", err)
  1222. }
  1223. return driver.RowsAffected(n), commandTag
  1224. }
// rows is the result cursor that stmt.Query returns as a driver.Rows. It
// reads result messages from the connection as Next is called.
type rows struct {
	cn       *conn
	colNames []string
	colTyps  []oid.Oid
	colFmts  []format
	// done is set once ReadyForQuery has been processed; Next then returns
	// io.EOF without touching the connection.
	done bool
	// rb is the buffer handed to recv1Buf for every message Next reads.
	rb readBuf
}
  1233. func (rs *rows) Close() error {
  1234. // no need to look at cn.bad as Next() will
  1235. for {
  1236. err := rs.Next(nil)
  1237. switch err {
  1238. case nil:
  1239. case io.EOF:
  1240. return nil
  1241. default:
  1242. return err
  1243. }
  1244. }
  1245. }
  1246. func (rs *rows) Columns() []string {
  1247. return rs.colNames
  1248. }
  1249. func (rs *rows) Next(dest []driver.Value) (err error) {
  1250. if rs.done {
  1251. return io.EOF
  1252. }
  1253. conn := rs.cn
  1254. if conn.bad {
  1255. return driver.ErrBadConn
  1256. }
  1257. defer conn.errRecover(&err)
  1258. for {
  1259. t := conn.recv1Buf(&rs.rb)
  1260. switch t {
  1261. case 'E':
  1262. err = parseError(&rs.rb)
  1263. case 'C', 'I':
  1264. continue
  1265. case 'Z':
  1266. conn.processReadyForQuery(&rs.rb)
  1267. rs.done = true
  1268. if err != nil {
  1269. return err
  1270. }
  1271. return io.EOF
  1272. case 'D':
  1273. n := rs.rb.int16()
  1274. if err != nil {
  1275. conn.bad = true
  1276. errorf("unexpected DataRow after error %s", err)
  1277. }
  1278. if n < len(dest) {
  1279. dest = dest[:n]
  1280. }
  1281. for i := range dest {
  1282. l := rs.rb.int32()
  1283. if l == -1 {
  1284. dest[i] = nil
  1285. continue
  1286. }
  1287. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i])
  1288. }
  1289. return
  1290. default:
  1291. errorf("unexpected message after execute: %q", t)
  1292. }
  1293. }
  1294. }
  1295. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1296. // used as part of an SQL statement. For example:
  1297. //
  1298. // tblname := "my_table"
  1299. // data := "my_data"
  1300. // err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
  1301. //
  1302. // Any double quotes in name will be escaped. The quoted identifier will be
  1303. // case sensitive when used in a query. If the input string contains a zero
  1304. // byte, the result will be truncated immediately before it.
  1305. func QuoteIdentifier(name string) string {
  1306. end := strings.IndexRune(name, 0)
  1307. if end > -1 {
  1308. name = name[:end]
  1309. }
  1310. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1311. }
  1312. func md5s(s string) string {
  1313. h := md5.New()
  1314. h.Write([]byte(s))
  1315. return fmt.Sprintf("%x", h.Sum(nil))
  1316. }
  1317. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1318. // Do one pass over the parameters to see if we're going to send any of
  1319. // them over in binary. If we are, create a paramFormats array at the
  1320. // same time.
  1321. var paramFormats []int
  1322. for i, x := range args {
  1323. _, ok := x.([]byte)
  1324. if ok {
  1325. if paramFormats == nil {
  1326. paramFormats = make([]int, len(args))
  1327. }
  1328. paramFormats[i] = 1
  1329. }
  1330. }
  1331. if paramFormats == nil {
  1332. b.int16(0)
  1333. } else {
  1334. b.int16(len(paramFormats))
  1335. for _, x := range paramFormats {
  1336. b.int16(x)
  1337. }
  1338. }
  1339. b.int16(len(args))
  1340. for _, x := range args {
  1341. if x == nil {
  1342. b.int32(-1)
  1343. } else {
  1344. datum := binaryEncode(&cn.parameterStatus, x)
  1345. b.int32(len(datum))
  1346. b.bytes(datum)
  1347. }
  1348. }
  1349. }
  1350. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1351. if len(args) >= 65536 {
  1352. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1353. }
  1354. b := cn.writeBuf('P')
  1355. b.byte(0) // unnamed statement
  1356. b.string(query)
  1357. b.int16(0)
  1358. b.next('B')
  1359. b.int16(0) // unnamed portal and statement
  1360. cn.sendBinaryParameters(b, args)
  1361. b.bytes(colFmtDataAllText)
  1362. b.next('D')
  1363. b.byte('P')
  1364. b.byte(0) // unnamed portal
  1365. b.next('E')
  1366. b.byte(0)
  1367. b.int32(0)
  1368. b.next('S')
  1369. cn.send(b)
  1370. }
  1371. func (c *conn) processParameterStatus(r *readBuf) {
  1372. var err error
  1373. param := r.string()
  1374. switch param {
  1375. case "server_version":
  1376. var major1 int
  1377. var major2 int
  1378. var minor int
  1379. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1380. if err == nil {
  1381. c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1382. }
  1383. case "TimeZone":
  1384. c.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1385. if err != nil {
  1386. c.parameterStatus.currentLocation = nil
  1387. }
  1388. default:
  1389. // ignore
  1390. }
  1391. }
  1392. func (c *conn) processReadyForQuery(r *readBuf) {
  1393. c.txnStatus = transactionStatus(r.byte())
  1394. }
  1395. func (cn *conn) readReadyForQuery() {
  1396. t, r := cn.recv1()
  1397. switch t {
  1398. case 'Z':
  1399. cn.processReadyForQuery(r)
  1400. return
  1401. default:
  1402. cn.bad = true
  1403. errorf("unexpected message %q; expected ReadyForQuery", t)
  1404. }
  1405. }
  1406. func (cn *conn) readParseResponse() {
  1407. t, r := cn.recv1()
  1408. switch t {
  1409. case '1':
  1410. return
  1411. case 'E':
  1412. err := parseError(r)
  1413. cn.readReadyForQuery()
  1414. panic(err)
  1415. default:
  1416. cn.bad = true
  1417. errorf("unexpected Parse response %q", t)
  1418. }
  1419. }
  1420. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) {
  1421. for {
  1422. t, r := cn.recv1()
  1423. switch t {
  1424. case 't':
  1425. nparams := r.int16()
  1426. paramTyps = make([]oid.Oid, nparams)
  1427. for i := range paramTyps {
  1428. paramTyps[i] = r.oid()
  1429. }
  1430. case 'n':
  1431. return paramTyps, nil, nil
  1432. case 'T':
  1433. colNames, colTyps = parseStatementRowDescribe(r)
  1434. return paramTyps, colNames, colTyps
  1435. case 'E':
  1436. err := parseError(r)
  1437. cn.readReadyForQuery()
  1438. panic(err)
  1439. default:
  1440. cn.bad = true
  1441. errorf("unexpected Describe statement response %q", t)
  1442. }
  1443. }
  1444. }
  1445. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) {
  1446. t, r := cn.recv1()
  1447. switch t {
  1448. case 'T':
  1449. return parsePortalRowDescribe(r)
  1450. case 'n':
  1451. return nil, nil, nil
  1452. case 'E':
  1453. err := parseError(r)
  1454. cn.readReadyForQuery()
  1455. panic(err)
  1456. default:
  1457. cn.bad = true
  1458. errorf("unexpected Describe response %q", t)
  1459. }
  1460. panic("not reached")
  1461. }
  1462. func (cn *conn) readBindResponse() {
  1463. t, r := cn.recv1()
  1464. switch t {
  1465. case '2':
  1466. return
  1467. case 'E':
  1468. err := parseError(r)
  1469. cn.readReadyForQuery()
  1470. panic(err)
  1471. default:
  1472. cn.bad = true
  1473. errorf("unexpected Bind response %q", t)
  1474. }
  1475. }
  1476. func (cn *conn) postExecuteWorkaround() {
  1477. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1478. // any errors from rows.Next, which masks errors that happened during the
  1479. // execution of the query. To avoid the problem in common cases, we wait
  1480. // here for one more message from the database. If it's not an error the
  1481. // query will likely succeed (or perhaps has already, if it's a
  1482. // CommandComplete), so we push the message into the conn struct; recv1
  1483. // will return it as the next message for rows.Next or rows.Close.
  1484. // However, if it's an error, we wait until ReadyForQuery and then return
  1485. // the error to our caller.
  1486. for {
  1487. t, r := cn.recv1()
  1488. switch t {
  1489. case 'E':
  1490. err := parseError(r)
  1491. cn.readReadyForQuery()
  1492. panic(err)
  1493. case 'C', 'D', 'I':
  1494. // the query didn't fail, but we can't process this message
  1495. cn.saveMessage(t, r)
  1496. return
  1497. default:
  1498. cn.bad = true
  1499. errorf("unexpected message during extended query execution: %q", t)
  1500. }
  1501. }
  1502. }
  1503. // Only for Exec(), since we ignore the returned data
  1504. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1505. for {
  1506. t, r := cn.recv1()
  1507. switch t {
  1508. case 'C':
  1509. if err != nil {
  1510. cn.bad = true
  1511. errorf("unexpected CommandComplete after error %s", err)
  1512. }
  1513. res, commandTag = cn.parseComplete(r.string())
  1514. case 'Z':
  1515. cn.processReadyForQuery(r)
  1516. return res, commandTag, err
  1517. case 'E':
  1518. err = parseError(r)
  1519. case 'T', 'D', 'I':
  1520. if err != nil {
  1521. cn.bad = true
  1522. errorf("unexpected %q after error %s", t, err)
  1523. }
  1524. // ignore any results
  1525. default:
  1526. cn.bad = true
  1527. errorf("unknown %s response: %q", protocolState, t)
  1528. }
  1529. }
  1530. }
  1531. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) {
  1532. n := r.int16()
  1533. colNames = make([]string, n)
  1534. colTyps = make([]oid.Oid, n)
  1535. for i := range colNames {
  1536. colNames[i] = r.string()
  1537. r.next(6)
  1538. colTyps[i] = r.oid()
  1539. r.next(6)
  1540. // format code not known when describing a statement; always 0
  1541. r.next(2)
  1542. }
  1543. return
  1544. }
  1545. func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) {
  1546. n := r.int16()
  1547. colNames = make([]string, n)
  1548. colFmts = make([]format, n)
  1549. colTyps = make([]oid.Oid, n)
  1550. for i := range colNames {
  1551. colNames[i] = r.string()
  1552. r.next(6)
  1553. colTyps[i] = r.oid()
  1554. r.next(6)
  1555. colFmts[i] = format(r.int16())
  1556. }
  1557. return
  1558. }
  1559. // parseEnviron tries to mimic some of libpq's environment handling
  1560. //
  1561. // To ease testing, it does not directly reference os.Environ, but is
  1562. // designed to accept its output.
  1563. //
  1564. // Environment-set connection information is intended to have a higher
  1565. // precedence than a library default but lower than any explicitly
  1566. // passed information (such as in the URL or connection string).
  1567. func parseEnviron(env []string) (out map[string]string) {
  1568. out = make(map[string]string)
  1569. for _, v := range env {
  1570. parts := strings.SplitN(v, "=", 2)
  1571. accrue := func(keyname string) {
  1572. out[keyname] = parts[1]
  1573. }
  1574. unsupported := func() {
  1575. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1576. }
  1577. // The order of these is the same as is seen in the
  1578. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1579. // keys cause a panic; these should be unset prior to
  1580. // execution. Options which pq expects to be set to a
  1581. // certain value are allowed, but must be set to that
  1582. // value if present (they can, of course, be absent).
  1583. switch parts[0] {
  1584. case "PGHOST":
  1585. accrue("host")
  1586. case "PGHOSTADDR":
  1587. unsupported()
  1588. case "PGPORT":
  1589. accrue("port")
  1590. case "PGDATABASE":
  1591. accrue("dbname")
  1592. case "PGUSER":
  1593. accrue("user")
  1594. case "PGPASSWORD":
  1595. accrue("password")
  1596. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1597. unsupported()
  1598. case "PGOPTIONS":
  1599. accrue("options")
  1600. case "PGAPPNAME":
  1601. accrue("application_name")
  1602. case "PGSSLMODE":
  1603. accrue("sslmode")
  1604. case "PGSSLCERT":
  1605. accrue("sslcert")
  1606. case "PGSSLKEY":
  1607. accrue("sslkey")
  1608. case "PGSSLROOTCERT":
  1609. accrue("sslrootcert")
  1610. case "PGREQUIRESSL", "PGSSLCRL":
  1611. unsupported()
  1612. case "PGREQUIREPEER":
  1613. unsupported()
  1614. case "PGKRBSRVNAME", "PGGSSLIB":
  1615. unsupported()
  1616. case "PGCONNECT_TIMEOUT":
  1617. accrue("connect_timeout")
  1618. case "PGCLIENTENCODING":
  1619. accrue("client_encoding")
  1620. case "PGDATESTYLE":
  1621. accrue("datestyle")
  1622. case "PGTZ":
  1623. accrue("timezone")
  1624. case "PGGEQO":
  1625. accrue("geqo")
  1626. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1627. unsupported()
  1628. }
  1629. }
  1630. return out
  1631. }
  1632. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1633. func isUTF8(name string) bool {
  1634. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1635. s := strings.Map(alnumLowerASCII, name)
  1636. return s == "utf8" || s == "unicode"
  1637. }
  1638. func alnumLowerASCII(ch rune) rune {
  1639. if 'A' <= ch && ch <= 'Z' {
  1640. return ch + ('a' - 'A')
  1641. }
  1642. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1643. return ch
  1644. }
  1645. return -1 // discard
  1646. }