https://github.com/etcd-io/etcd
Raw File
Tip revision: 8a03d2e9614b8192ebaa5a25ef92f6ff62e3593c authored by Gyuho Lee on 25 November 2020, 19:31:52 UTC
version: 3.4.14
Tip revision: 8a03d2e
reverse.go
// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package httpproxy

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync/atomic"
	"time"

	"github.com/coreos/pkg/capnslog"
	"go.etcd.io/etcd/etcdserver/api/v2http/httptypes"
)

var (
	// plog is the logger shared by the httpproxy package.
	plog = capnslog.NewPackageLogger("go.etcd.io/etcd", "proxy/httpproxy")

	// singleHopHeaders lists the hop-by-hop headers. They are meaningful only
	// for a single transport-level connection, so they are stripped from the
	// request before it is sent to the backend and from the backend response
	// before it is relayed to the client.
	// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
	// This list of headers borrowed from stdlib httputil.ReverseProxy
	singleHopHeaders = []string{
		"Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",       // canonicalized version of "TE"
		"Trailer",  // the real header name; the RFC 2616 list above misspells it (RFC errata 4522)
		"Trailers", // misspelling from RFC 2616 section 13.5.1, kept for backward compatibility
		"Transfer-Encoding",
		"Upgrade",
	}
)

// removeSingleHopHeaders deletes every header named in singleHopHeaders from
// the header map that hdrs points to. The map is modified in place.
func removeSingleHopHeaders(hdrs *http.Header) {
	header := *hdrs
	for i := 0; i < len(singleHopHeaders); i++ {
		header.Del(singleHopHeaders[i])
	}
}

// reverseProxy is an http.Handler that forwards each incoming request to one
// of the endpoints reported by its director and relays the backend response
// to the client.
type reverseProxy struct {
	// director supplies the list of candidate backend endpoints.
	director  *director
	// transport performs the round trip to the selected backend endpoint.
	transport http.RoundTripper
}

// ServeHTTP forwards clientreq to the backend endpoints reported by the
// director, trying them in order until one returns a response, and relays
// that response (status, headers minus hop-by-hop headers, and body) to rw.
//
// The request body is read fully into memory so that it can be replayed
// against each endpoint that is tried. Error responses written by the proxy
// itself are:
//   - 500 if the client request body cannot be read,
//   - 503 if the director reports no endpoints,
//   - 502 if no endpoint produced a response.
//
// If rw implements http.CloseNotifier and the client goes away while the
// request is in flight, the outgoing request is cancelled and the handler
// returns without writing a response.
func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) {
	reportIncomingRequest(clientreq)
	proxyreq := new(http.Request)
	*proxyreq = *clientreq
	startTime := time.Now()

	var (
		proxybody []byte
		err       error
	)

	// Buffer the body up front: it may have to be sent more than once if the
	// first endpoint(s) fail.
	if clientreq.Body != nil {
		proxybody, err = ioutil.ReadAll(clientreq.Body)
		if err != nil {
			msg := fmt.Sprintf("failed to read request body: %v", err)
			plog.Println(msg)
			e := httptypes.NewHTTPError(http.StatusInternalServerError, "httpproxy: "+msg)
			if we := e.WriteTo(rw); we != nil {
				plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
			}
			return
		}
	}

	// deep-copy the headers, as these will be modified below
	proxyreq.Header = make(http.Header)
	copyHeader(proxyreq.Header, clientreq.Header)

	normalizeRequest(proxyreq)
	removeSingleHopHeaders(&proxyreq.Header)
	maybeSetForwardedFor(proxyreq)

	endpoints := p.director.endpoints()
	if len(endpoints) == 0 {
		msg := "zero endpoints currently available"
		reportRequestDropped(clientreq, zeroEndpoints)

		// TODO: limit the rate of the error logging.
		plog.Println(msg)
		e := httptypes.NewHTTPError(http.StatusServiceUnavailable, "httpproxy: "+msg)
		if we := e.WriteTo(rw); we != nil {
			plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
		}
		return
	}

	// requestClosed is set to 1 by the watcher goroutine below when the
	// client connection is reported closed; it is read after each round trip.
	var requestClosed int32
	// completeCh is buffered so the deferred send never blocks, even if the
	// watcher goroutine has already exited through the closeCh branch.
	completeCh := make(chan bool, 1)
	closeNotifier, ok := rw.(http.CloseNotifier)
	ctx, cancel := context.WithCancel(context.Background())
	proxyreq = proxyreq.WithContext(ctx)
	defer cancel()
	if ok {
		closeCh := closeNotifier.CloseNotify()
		go func() {
			select {
			case <-closeCh:
				atomic.StoreInt32(&requestClosed, 1)
				plog.Printf("client %v closed request prematurely", clientreq.RemoteAddr)
				cancel()
			case <-completeCh:
			}
		}()

		defer func() {
			completeCh <- true
		}()
	}

	var res *http.Response

	for _, ep := range endpoints {
		if proxybody != nil {
			// Fresh reader per attempt so every endpoint sees the full body.
			proxyreq.Body = ioutil.NopCloser(bytes.NewBuffer(proxybody))
		}
		redirectRequest(proxyreq, ep.URL)

		res, err = p.transport.RoundTrip(proxyreq)
		if atomic.LoadInt32(&requestClosed) == 1 {
			// The client is gone, so nothing will be written back. The round
			// trip may still have succeeded; the caller of RoundTrip owns the
			// response body and must close it to release the connection.
			if res != nil && res.Body != nil {
				res.Body.Close()
			}
			return
		}
		if err != nil {
			reportRequestDropped(clientreq, failedSendingRequest)
			plog.Printf("failed to direct request to %s: %v", ep.URL.String(), err)
			ep.Failed()
			continue
		}

		break
	}

	if res == nil {
		// TODO: limit the rate of the error logging.
		msg := fmt.Sprintf("unable to get response from %d endpoint(s)", len(endpoints))
		reportRequestDropped(clientreq, failedGettingResponse)
		plog.Println(msg)
		e := httptypes.NewHTTPError(http.StatusBadGateway, "httpproxy: "+msg)
		if we := e.WriteTo(rw); we != nil {
			plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
		}
		return
	}

	defer res.Body.Close()
	reportRequestHandled(clientreq, res, startTime)
	removeSingleHopHeaders(&res.Header)
	copyHeader(rw.Header(), res.Header)

	rw.WriteHeader(res.StatusCode)
	// The status line is already sent, so a copy failure cannot be reported
	// to the client; record it for debugging instead of dropping it silently.
	if _, cerr := io.Copy(rw, res.Body); cerr != nil {
		plog.Debugf("error copying response body (%v) to %s", cerr, clientreq.RemoteAddr)
	}
}

// copyHeader appends every value of every header in src to dst. Existing
// values in dst are kept; keys are canonicalized by http.Header.Add.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}

// redirectRequest points req at the backend described by loc by overwriting
// the scheme and host of req.URL. All other URL components are left as is.
func redirectRequest(req *http.Request, loc url.URL) {
	target := req.URL
	target.Scheme, target.Host = loc.Scheme, loc.Host
}

// normalizeRequest rewrites req so that it is sent to the backend as a
// HTTP/1.1 request and clears the Close flag on it.
func normalizeRequest(req *http.Request) {
	req.Proto, req.ProtoMajor, req.ProtoMinor = "HTTP/1.1", 1, 1
	req.Close = false
}

// maybeSetForwardedFor records the client IP taken from req.RemoteAddr in the
// X-Forwarded-For header of req. If RemoteAddr cannot be split into host and
// port, the request is left untouched.
func maybeSetForwardedFor(req *http.Request) {
	const forwardedFor = "X-Forwarded-For"

	host, _, splitErr := net.SplitHostPort(req.RemoteAddr)
	if splitErr != nil {
		return
	}

	// When earlier proxies already set the header, keep their entries: all
	// existing values are folded into one comma+space separated list and the
	// client IP is appended at the end.
	forwarded := host
	if earlier, present := req.Header[forwardedFor]; present {
		forwarded = strings.Join(earlier, ", ") + ", " + host
	}
	req.Header.Set(forwardedFor, forwarded)
}
back to top