1// Copyright 2018 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package transport provides a mechanism to send requests with https cert,
16// key, and CA.
17package transport
18
19import (
20	"crypto/tls"
21	"crypto/x509"
22	"fmt"
23	"net/http"
24	"os"
25	"sync"
26
27	"github.com/google/pprof/internal/plugin"
28)
29
// transport is the http.RoundTripper returned by New. It holds the
// user-supplied TLS file paths and the TLS material loaded from them.
type transport struct {
	// File paths for the client certificate, its private key, and the CA
	// bundle. Each pointer is nil when no flag set was supplied to New.
	cert, key, ca *string

	// TLS material populated by initialize from the paths above.
	caCertPool *x509.CertPool
	certs      []tls.Certificate

	// initOnce ensures initialize runs a single time; initErr keeps the
	// error it returned so later round trips can report it again.
	initOnce sync.Once
	initErr  error
}
39
// extraUsage is the usage text for the TLS flags; New hands it to the flag
// set's AddExtraUsage. One flag per line, with no trailing newline.
const extraUsage = "    -tls_cert             TLS client certificate file for fetching profile and symbols\n" +
	"    -tls_key              TLS private key file for fetching profile and symbols\n" +
	"    -tls_ca               TLS CA certs file for fetching profile and symbols"
43
44// New returns a round tripper for making requests with the
45// specified cert, key, and ca. The flags tls_cert, tls_key, and tls_ca are
46// added to the flagset to allow a user to specify the cert, key, and ca. If
47// the flagset is nil, no flags will be added, and users will not be able to
48// use these flags.
49func New(flagset plugin.FlagSet) http.RoundTripper {
50	if flagset == nil {
51		return &transport{}
52	}
53	flagset.AddExtraUsage(extraUsage)
54	return &transport{
55		cert: flagset.String("tls_cert", "", "TLS client certificate file for fetching profile and symbols"),
56		key:  flagset.String("tls_key", "", "TLS private key file for fetching profile and symbols"),
57		ca:   flagset.String("tls_ca", "", "TLS CA certs file for fetching profile and symbols"),
58	}
59}
60
61// initialize uses the cert, key, and ca to initialize the certs
62// to use these when making requests.
63func (tr *transport) initialize() error {
64	var cert, key, ca string
65	if tr.cert != nil {
66		cert = *tr.cert
67	}
68	if tr.key != nil {
69		key = *tr.key
70	}
71	if tr.ca != nil {
72		ca = *tr.ca
73	}
74
75	if cert != "" && key != "" {
76		tlsCert, err := tls.LoadX509KeyPair(cert, key)
77		if err != nil {
78			return fmt.Errorf("could not load certificate/key pair specified by -tls_cert and -tls_key: %v", err)
79		}
80		tr.certs = []tls.Certificate{tlsCert}
81	} else if cert == "" && key != "" {
82		return fmt.Errorf("-tls_key is specified, so -tls_cert must also be specified")
83	} else if cert != "" && key == "" {
84		return fmt.Errorf("-tls_cert is specified, so -tls_key must also be specified")
85	}
86
87	if ca != "" {
88		caCertPool := x509.NewCertPool()
89		caCert, err := os.ReadFile(ca)
90		if err != nil {
91			return fmt.Errorf("could not load CA specified by -tls_ca: %v", err)
92		}
93		caCertPool.AppendCertsFromPEM(caCert)
94		tr.caCertPool = caCertPool
95	}
96
97	return nil
98}
99
100// RoundTrip executes a single HTTP transaction, returning
101// a Response for the provided Request.
102func (tr *transport) RoundTrip(req *http.Request) (*http.Response, error) {
103	tr.initOnce.Do(func() {
104		tr.initErr = tr.initialize()
105	})
106	if tr.initErr != nil {
107		return nil, tr.initErr
108	}
109
110	tlsConfig := &tls.Config{
111		RootCAs:      tr.caCertPool,
112		Certificates: tr.certs,
113	}
114
115	if req.URL.Scheme == "https+insecure" {
116		// Make shallow copy of request, and req.URL, so the request's URL can be
117		// modified.
118		r := *req
119		*r.URL = *req.URL
120		req = &r
121		tlsConfig.InsecureSkipVerify = true
122		req.URL.Scheme = "https"
123	}
124
125	transport := http.Transport{
126		Proxy:           http.ProxyFromEnvironment,
127		TLSClientConfig: tlsConfig,
128	}
129
130	return transport.RoundTrip(req)
131}
132