Revision 9a5533382d84999e4e79642e1ec0f8bfa9b70ba8 authored by Wenjia Zhang on 19 July 2024, 20:13:00 UTC, committed by Wenjia Zhang on 19 July 2024, 20:13:00 UTC
1 parent d6479db
Raw File
client_test.go
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package clientv3

import (
	"context"
	"fmt"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"go.etcd.io/etcd/api/v3/etcdserverpb"
	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
	"go.etcd.io/etcd/client/pkg/v3/testutil"
	"go.uber.org/zap"
	"go.uber.org/zap/zaptest"

	"google.golang.org/grpc"
)

// NewClient builds a Client from cfg after replacing cfg.Logger with a
// zaptest logger bound to t. cfg is received by value, so the caller's
// Config is left unmodified.
func NewClient(t *testing.T, cfg Config) (*Client, error) {
	testLogger := zaptest.NewLogger(t)
	cfg.Logger = testLogger
	return New(cfg)
}

// TestDialCancel checks that Close returns promptly, and that an in-flight
// Get exits, while the client is pointed at an endpoint where dialing blocks.
func TestDialCancel(t *testing.T) {
	testutil.RegisterLeakDetection(t)

	// Listen on a unix socket so the initial client creation has something
	// to connect to (the client is created with a dial timeout).
	listener, err := net.Listen("unix", "dialcancel:12345")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()

	c, err := NewClient(t, Config{
		Endpoints:   []string{"unix://dialcancel:12345"},
		DialTimeout: 30 * time.Second,
	})
	if err != nil {
		t.Fatal(err)
	}

	// Switch to an ipv4 black hole so that subsequent dials block.
	c.SetEndpoints("http://254.0.0.1:12345")

	// Issue a Get in the background to force redial attempts.
	getDone := make(chan struct{})
	go func() {
		defer close(getDone)
		// Get may hang forever on grpc's Stream.Header() if its
		// context is never canceled.
		c.Get(c.Ctx(), "abc")
	}()

	// Give the dial a moment to start before closing the client.
	time.Sleep(100 * time.Millisecond)

	closeDone := make(chan struct{})
	go func() {
		defer close(closeDone)
		c.Close()
	}()

	// Close must finish within the deadline...
	select {
	case <-closeDone:
	case <-time.After(5 * time.Second):
		t.Fatalf("failed to close")
	}
	// ...and so must the pending Get.
	select {
	case <-getDone:
	case <-time.After(5 * time.Second):
		t.Fatalf("get failed to exit")
	}
}

// TestDialTimeout checks that creating a client with a blocking dial against
// an unreachable endpoint waits for DialTimeout and then fails with
// context.DeadlineExceeded, both with and without credentials configured.
func TestDialTimeout(t *testing.T) {
	testutil.RegisterLeakDetection(t)

	wantError := context.DeadlineExceeded

	// grpc.WithBlock to block until connection up or timeout
	testCfgs := []Config{
		{
			Endpoints:   []string{"http://254.0.0.1:12345"},
			DialTimeout: 2 * time.Second,
			DialOptions: []grpc.DialOption{grpc.WithBlock()},
		},
		{
			Endpoints:   []string{"http://254.0.0.1:12345"},
			DialTimeout: time.Second,
			DialOptions: []grpc.DialOption{grpc.WithBlock()},
			Username:    "abc",
			Password:    "def",
		},
	}

	for i, cfg := range testCfgs {
		// Buffered so the dialing goroutine can always deliver its result
		// and exit, even if the test has stopped waiting for it.
		donec := make(chan error, 1)
		// Pass i explicitly: before Go 1.22 the loop variable is shared
		// across iterations, so capturing it in the closure is racy and may
		// report the wrong case index.
		go func(i int, cfg Config) {
			// without timeout, dial continues forever on ipv4 black hole
			c, err := NewClient(t, cfg)
			if c != nil || err == nil {
				t.Errorf("#%d: new client should fail", i)
			}
			if c != nil {
				// Do not leak an unexpectedly created client.
				c.Close()
			}
			donec <- err
		}(i, cfg)

		time.Sleep(10 * time.Millisecond)

		// The dial must still be blocked shortly after it started.
		select {
		case err := <-donec:
			t.Errorf("#%d: dial didn't wait (%v)", i, err)
		default:
		}

		select {
		case <-time.After(5 * time.Second):
			t.Errorf("#%d: failed to timeout dial on time", i)
		case err := <-donec:
			// Guard against nil so an unexpected success is reported as a
			// test failure instead of a nil-pointer panic on err.Error().
			if err == nil || err.Error() != wantError.Error() {
				t.Errorf("#%d: unexpected error '%v', want '%v'", i, err, wantError)
			}
		}
	}
}

func TestDialNoTimeout(t *testing.T) {
	cfg := Config{Endpoints: []string{"127.0.0.1:12345"}}
	c, err := NewClient(t, cfg)
	if c == nil || err != nil {
		t.Fatalf("new client with DialNoWait should succeed, got %v", err)
	}
	c.Close()
}

func TestMaxUnaryRetries(t *testing.T) {
	maxUnaryRetries := uint(10)
	cfg := Config{
		Endpoints:       []string{"127.0.0.1:12345"},
		MaxUnaryRetries: maxUnaryRetries,
	}
	c, err := NewClient(t, cfg)
	if c == nil || err != nil {
		t.Fatalf("new client with MaxUnaryRetries should succeed, got %v", err)
	}
	defer c.Close()

	if c.cfg.MaxUnaryRetries != maxUnaryRetries {
		t.Fatalf("client MaxUnaryRetries should be %d, got %d", maxUnaryRetries, c.cfg.MaxUnaryRetries)
	}
}

func TestBackoff(t *testing.T) {
	backoffWaitBetween := 100 * time.Millisecond
	cfg := Config{
		Endpoints:          []string{"127.0.0.1:12345"},
		BackoffWaitBetween: backoffWaitBetween,
	}
	c, err := NewClient(t, cfg)
	if c == nil || err != nil {
		t.Fatalf("new client with BackoffWaitBetween should succeed, got %v", err)
	}
	defer c.Close()

	if c.cfg.BackoffWaitBetween != backoffWaitBetween {
		t.Fatalf("client BackoffWaitBetween should be %v, got %v", backoffWaitBetween, c.cfg.BackoffWaitBetween)
	}
}

func TestBackoffJitterFraction(t *testing.T) {
	backoffJitterFraction := float64(0.9)
	cfg := Config{
		Endpoints:             []string{"127.0.0.1:12345"},
		BackoffJitterFraction: backoffJitterFraction,
	}
	c, err := NewClient(t, cfg)
	if c == nil || err != nil {
		t.Fatalf("new client with BackoffJitterFraction should succeed, got %v", err)
	}
	defer c.Close()

	if c.cfg.BackoffJitterFraction != backoffJitterFraction {
		t.Fatalf("client BackoffJitterFraction should be %v, got %v", backoffJitterFraction, c.cfg.BackoffJitterFraction)
	}
}

// TestIsHaltErr exercises isHaltErr with an "etcdserver: "-prefixed error,
// with the ErrGRPCStopped and ErrGRPCNoLeader errors, and with a nil error on
// a context before and after cancellation.
func TestIsHaltErr(t *testing.T) {
	bg := context.TODO()

	serverErr := fmt.Errorf("etcdserver: some etcdserver error")
	if halted := isHaltErr(bg, serverErr); !halted {
		t.Errorf(`error prefixed with "etcdserver: " should be Halted by default`)
	}

	// These errors are expected not to be reported as halting.
	for _, nonHalting := range []error{rpctypes.ErrGRPCStopped, rpctypes.ErrGRPCNoLeader} {
		if isHaltErr(bg, nonHalting) {
			t.Errorf("error %v should not halt", nonHalting)
		}
	}

	// With no error, the result is expected to follow the context state.
	cancelable, cancel := context.WithCancel(context.TODO())
	if isHaltErr(cancelable, nil) {
		t.Errorf("no error and active context should not be Halted")
	}
	cancel()
	if halted := isHaltErr(cancelable, nil); !halted {
		t.Errorf("cancel on context should be Halted")
	}
}

// TestCloseCtxClient checks that Close on a client built with NewCtxClient
// returns a non-nil error.
func TestCloseCtxClient(t *testing.T) {
	c := NewCtxClient(context.Background())

	// Close returns ctx.toErr, a nil error means an open Done channel
	if err := c.Close(); err == nil {
		t.Errorf("failed to Close the client. %v", err)
	}
}

// TestWithLogger checks that a client from NewCtxClient starts with a non-nil
// logger and that WithLogger(nil) replaces it with nil.
func TestWithLogger(t *testing.T) {
	c := NewCtxClient(context.Background())
	if c.lg == nil {
		t.Errorf("unexpected nil in *zap.Logger")
	}

	// After overriding with nil, the stored logger must be nil as well.
	c.WithLogger(nil)
	if c.lg != nil {
		t.Errorf("WithLogger should modify *zap.Logger")
	}
}

// TestZapWithLogger checks that the WithZapLogger option makes NewCtxClient
// store exactly the logger that was passed in.
func TestZapWithLogger(t *testing.T) {
	nop := zap.NewNop()
	c := NewCtxClient(context.Background(), WithZapLogger(nop))

	if got := c.lg; got != nop {
		t.Errorf("WithZapLogger should modify *zap.Logger")
	}
}

// TestAuthTokenBundleNoOverwrite checks that calling Client.Dial on an
// already-authenticated client leaves Client.authTokenBundle unchanged.
func TestAuthTokenBundleNoOverwrite(t *testing.T) {
	// Start a gRPC server with a mock AuthServer to answer Authenticate RPCs.
	listener, err := net.Listen("unix", "etcd-auth-test:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()
	endpoint := "unix:" + listener.Addr().String()

	server := grpc.NewServer()
	etcdserverpb.RegisterAuthServer(server, mockAuthServer{})
	go server.Serve(listener)
	defer server.Stop()

	// Creating the client with a username/password should call Authenticate
	// on the mock server to obtain an auth token.
	clientCfg := Config{
		DialTimeout: 5 * time.Second,
		Endpoints:   []string{endpoint},
		Username:    "foo",
		Password:    "bar",
	}
	c, err := NewClient(t, clientCfg)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	before := c.authTokenBundle

	// A second, explicit Dial should preserve the original authTokenBundle.
	conn, err := c.Dial(endpoint)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if after := c.authTokenBundle; before != after {
		t.Error("Client.authTokenBundle has been overwritten during Client.Dial")
	}
}

// mockAuthServer is a test AuthServer that embeds
// *etcdserverpb.UnimplementedAuthServer and defines only Authenticate itself.
type mockAuthServer struct {
	*etcdserverpb.UnimplementedAuthServer
}

// Authenticate ignores its arguments and always returns a response carrying
// the fixed token "mock-token" and a nil error.
func (mockAuthServer) Authenticate(_ context.Context, _ *etcdserverpb.AuthenticateRequest) (*etcdserverpb.AuthenticateResponse, error) {
	resp := &etcdserverpb.AuthenticateResponse{Token: "mock-token"}
	return resp, nil
}

// TestSyncFiltersMembers checks that Client.Sync only keeps the client URLs
// of members that have a name and are not learners: of the three mocked
// members, only the third one's URL is expected to remain as an endpoint.
func TestSyncFiltersMembers(t *testing.T) {
	c, err := NewClient(t, Config{Endpoints: []string{"http://254.0.0.1:12345"}})
	if err != nil {
		// Without this check a failed NewClient would leave c nil and the
		// deferred Close below would panic instead of failing the test.
		t.Fatal(err)
	}
	defer c.Close()

	// Replace the real Cluster API with a mock returning a fixed member list.
	c.Cluster = &mockCluster{
		[]*etcdserverpb.Member{
			{ID: 0, Name: "", ClientURLs: []string{"http://254.0.0.1:12345"}, IsLearner: false},
			{ID: 1, Name: "isStarted", ClientURLs: []string{"http://254.0.0.2:12345"}, IsLearner: true},
			{ID: 2, Name: "isStartedAndNotLearner", ClientURLs: []string{"http://254.0.0.3:12345"}, IsLearner: false},
		},
	}
	if err := c.Sync(context.Background()); err != nil {
		// A failed Sync would otherwise surface only as a misleading
		// endpoint mismatch below.
		t.Fatal(err)
	}

	endpoints := c.Endpoints()
	if len(endpoints) != 1 || endpoints[0] != "http://254.0.0.3:12345" {
		t.Error("Client.Sync uses learner and/or non-started member client URLs")
	}
}

// mockCluster is a test double for the cluster API. Only MemberList returns
// data; every other method returns (nil, nil).
type mockCluster struct {
	// members is the list handed back by MemberList.
	members []*etcdserverpb.Member
}

// MemberList returns a response wrapping the configured members.
func (mc *mockCluster) MemberList(_ context.Context) (*MemberListResponse, error) {
	resp := &MemberListResponse{Members: mc.members}
	return resp, nil
}

// MemberAdd is a no-op stub.
func (*mockCluster) MemberAdd(_ context.Context, _ []string) (*MemberAddResponse, error) {
	return nil, nil
}

// MemberAddAsLearner is a no-op stub.
func (*mockCluster) MemberAddAsLearner(_ context.Context, _ []string) (*MemberAddResponse, error) {
	return nil, nil
}

// MemberRemove is a no-op stub.
func (*mockCluster) MemberRemove(_ context.Context, _ uint64) (*MemberRemoveResponse, error) {
	return nil, nil
}

// MemberUpdate is a no-op stub.
func (*mockCluster) MemberUpdate(_ context.Context, _ uint64, _ []string) (*MemberUpdateResponse, error) {
	return nil, nil
}

// MemberPromote is a no-op stub.
func (*mockCluster) MemberPromote(_ context.Context, _ uint64) (*MemberPromoteResponse, error) {
	return nil, nil
}

// TestClientRejectOldCluster runs Client.checkVersion against a mocked
// Maintenance API that reports a fixed version per endpoint, and checks the
// returned error against the expectation for each combination of versions.
func TestClientRejectOldCluster(t *testing.T) {
	testutil.RegisterLeakDetection(t)
	tests := []struct {
		name          string
		endpoints     []string
		versions      []string // versions[j] is reported for endpoints[j]
		expectedError error
	}{
		{
			name:          "all new versions with the same value",
			endpoints:     []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"},
			versions:      []string{"3.5.4", "3.5.4", "3.5.4"},
			expectedError: nil,
		},
		{
			name:          "all new versions with different values",
			endpoints:     []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"},
			versions:      []string{"3.5.4", "3.5.4", "3.4.0"},
			expectedError: nil,
		},
		{
			name:          "all old versions with different values",
			endpoints:     []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"},
			versions:      []string{"3.3.0", "3.3.0", "3.4.0"},
			expectedError: ErrOldCluster,
		},
		{
			name:          "all old versions with the same value",
			endpoints:     []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"},
			versions:      []string{"3.3.0", "3.3.0", "3.3.0"},
			expectedError: ErrOldCluster,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Sanity-check the table entry itself before using it.
			if len(tt.endpoints) != len(tt.versions) || len(tt.endpoints) == 0 {
				t.Errorf("Unexpected endpoints and versions length, len(endpoints):%d, len(versions):%d", len(tt.endpoints), len(tt.versions))
				return
			}
			endpointToVersion := make(map[string]string, len(tt.endpoints))
			for j, ep := range tt.endpoints {
				endpointToVersion[ep] = tt.versions[j]
			}
			// Build the client by hand so that checkVersion talks to the
			// mock Maintenance instead of a real server.
			c := &Client{
				ctx: context.Background(),
				cfg: Config{
					Endpoints: tt.endpoints,
				},
				mu: new(sync.RWMutex),
				Maintenance: &mockMaintenance{
					Version: endpointToVersion,
				},
			}

			if err := c.checkVersion(); err != tt.expectedError {
				t.Errorf("checkVersion err:%v, want:%v", err, tt.expectedError)
			}
		})
	}
}

// mockMaintenance is a test double for the maintenance API. Only Status
// returns data; every other method returns (nil, nil).
type mockMaintenance struct {
	// Version maps an endpoint to the version string Status reports for it.
	Version map[string]string
}

// Status returns a response whose Version is looked up by endpoint in the
// Version map; an endpoint missing from the map yields an empty version.
func (mm mockMaintenance) Status(_ context.Context, endpoint string) (*StatusResponse, error) {
	reported := mm.Version[endpoint]
	return &StatusResponse{Version: reported}, nil
}

// AlarmList is a no-op stub.
func (mockMaintenance) AlarmList(_ context.Context) (*AlarmResponse, error) {
	return nil, nil
}

// AlarmDisarm is a no-op stub.
func (mockMaintenance) AlarmDisarm(_ context.Context, _ *AlarmMember) (*AlarmResponse, error) {
	return nil, nil
}

// Defragment is a no-op stub.
func (mockMaintenance) Defragment(_ context.Context, _ string) (*DefragmentResponse, error) {
	return nil, nil
}

// HashKV is a no-op stub.
func (mockMaintenance) HashKV(_ context.Context, _ string, _ int64) (*HashKVResponse, error) {
	return nil, nil
}

// Snapshot is a no-op stub.
func (mockMaintenance) Snapshot(_ context.Context) (io.ReadCloser, error) {
	return nil, nil
}

// MoveLeader is a no-op stub.
func (mockMaintenance) MoveLeader(_ context.Context, _ uint64) (*MoveLeaderResponse, error) {
	return nil, nil
}
back to top