|
|
@@ -11,15 +11,31 @@ import (
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ // Maximum amount of time we should wait when reading a response from BIRD.
|
|
|
+ responseTimeout = 10 * time.Second
|
|
|
)
|
|
|
|
|
|
// New creates a BIRDClient.
|
|
|
func New(socket string) (*BIRDClient, error) {
|
|
|
+ return newWithTimeout(socket, responseTimeout)
|
|
|
+}
|
|
|
+
|
|
|
+func newWithTimeout(socket string, timeout time.Duration) (*BIRDClient, error) {
|
|
|
conn, err := net.Dial("unix", socket)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("failed to connect to BIRD: %w", err)
|
|
|
}
|
|
|
- b := &BIRDClient{socket: socket, conn: conn, scanner: bufio.NewScanner(conn)}
|
|
|
+ b := &BIRDClient{
|
|
|
+ socket: socket,
|
|
|
+ conn: conn,
|
|
|
+ scanner: bufio.NewScanner(conn),
|
|
|
+ timeNow: time.Now,
|
|
|
+ timeout: timeout,
|
|
|
+ }
|
|
|
// Read and discard the first line as that is the welcome message.
|
|
|
if _, err := b.readResponse(); err != nil {
|
|
|
return nil, err
|
|
|
@@ -32,6 +48,8 @@ type BIRDClient struct {
|
|
|
socket string
|
|
|
conn net.Conn
|
|
|
scanner *bufio.Scanner
|
|
|
+ timeNow func() time.Time
|
|
|
+ timeout time.Duration
|
|
|
}
|
|
|
|
|
|
// Close closes the underlying connection to BIRD.
|
|
|
@@ -81,10 +99,15 @@ func (b *BIRDClient) EnableProtocol(protocol string) error {
|
|
|
// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’.
|
|
|
|
|
|
func (b *BIRDClient) exec(cmd string, args ...any) (string, error) {
|
|
|
+ if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
- fmt.Fprintln(b.conn)
|
|
|
+ if _, err := fmt.Fprintln(b.conn); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
return b.readResponse()
|
|
|
}
|
|
|
|
|
|
@@ -105,14 +128,20 @@ func hasResponseCode(s []byte) bool {
|
|
|
}
|
|
|
|
|
|
func (b *BIRDClient) readResponse() (string, error) {
|
|
|
+ // Set the read timeout before we start reading anything.
|
|
|
+ if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+
|
|
|
var resp strings.Builder
|
|
|
var done bool
|
|
|
for !done {
|
|
|
if !b.scanner.Scan() {
|
|
|
- return "", fmt.Errorf("reading response from bird failed: %q", resp.String())
|
|
|
- }
|
|
|
- if err := b.scanner.Err(); err != nil {
|
|
|
- return "", err
|
|
|
+ if err := b.scanner.Err(); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+
|
|
|
+ return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String())
|
|
|
}
|
|
|
out := b.scanner.Bytes()
|
|
|
if _, err := resp.Write(out); err != nil {
|