1// Copyright 2017 The Bazel Authors. All rights reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// protoc invokes the protobuf compiler and captures the resulting .pb.go file. 16package main 17 18import ( 19 "bytes" 20 "errors" 21 "flag" 22 "fmt" 23 "io/ioutil" 24 "log" 25 "os" 26 "os/exec" 27 "path/filepath" 28 "runtime" 29 "strings" 30) 31 32type genFileInfo struct { 33 base string // The basename of the path 34 path string // The full path to the final file 35 expected bool // Whether the file is expected by the rules 36 created bool // Whether the file was created by protoc 37 from *genFileInfo // The actual file protoc produced if not Path 38 unique bool // True if this base name is unique in expected results 39 ambiguious bool // True if there were more than one possible outputs that matched this file 40} 41 42func run(args []string) error { 43 // process the args 44 args, useParamFile, err := expandParamsFiles(args) 45 if err != nil { 46 return err 47 } 48 options := multiFlag{} 49 descriptors := multiFlag{} 50 expected := multiFlag{} 51 imports := multiFlag{} 52 flags := flag.NewFlagSet("protoc", flag.ExitOnError) 53 protoc := flags.String("protoc", "", "The path to the real protoc.") 54 outPath := flags.String("out_path", "", "The base output path to write to.") 55 plugin := flags.String("plugin", "", "The go plugin to use.") 56 importpath := flags.String("importpath", "", "The importpath for the generated sources.") 57 flags.Var(&options, "option", "The plugin options.") 58 flags.Var(&descriptors, "descriptor_set", "The descriptor set to read.") 59 flags.Var(&expected, "expected", "The expected output files.") 60 flags.Var(&imports, "import", "Map a proto file to an import path.") 61 if err := flags.Parse(args); err != nil { 62 return err 63 } 64 65 // Output to a temporary folder and then move the contents into place below. 66 // This is to work around long file paths on Windows. 67 tmpDir, err := ioutil.TempDir("", "go_proto") 68 if err != nil { 69 return err 70 } 71 tmpDir = abs(tmpDir) // required to work with long paths on Windows 72 absOutPath := abs(*outPath) // required to work with long paths on Windows 73 defer os.RemoveAll(tmpDir) 74 75 pluginBase := filepath.Base(*plugin) 76 pluginName := strings.TrimSuffix( 77 strings.TrimPrefix(filepath.Base(*plugin), "protoc-gen-"), ".exe") 78 for _, m := range imports { 79 options = append(options, fmt.Sprintf("M%v", m)) 80 } 81 if runtime.GOOS == "windows" { 82 // Turn the plugin path into raw form, since we're handing it off to a non-go binary. 83 // This is required to work with long paths on Windows. 84 *plugin = "\\\\?\\" + abs(*plugin) 85 } 86 protoc_args := []string{ 87 fmt.Sprintf("--%v_out=%v:%v", pluginName, strings.Join(options, ","), tmpDir), 88 "--plugin", fmt.Sprintf("%v=%v", strings.TrimSuffix(pluginBase, ".exe"), *plugin), 89 "--descriptor_set_in", strings.Join(descriptors, string(os.PathListSeparator)), 90 } 91 protoc_args = append(protoc_args, flags.Args()...) 92 93 var cmd *exec.Cmd 94 if useParamFile { 95 paramFile, err := ioutil.TempFile(tmpDir, "protoc-*.params") 96 if err != nil { 97 return fmt.Errorf("error creating param file for protoc: %v", err) 98 } 99 for _, arg := range protoc_args { 100 _, err := fmt.Fprintln(paramFile, arg) 101 if err != nil { 102 return fmt.Errorf("error writing param file for protoc: %v", err) 103 } 104 } 105 cmd = exec.Command(*protoc, "@"+paramFile.Name()) 106 } else { 107 cmd = exec.Command(*protoc, protoc_args...) 108 } 109 110 cmd.Stdout = os.Stdout 111 cmd.Stderr = os.Stderr 112 if err := cmd.Run(); err != nil { 113 return fmt.Errorf("error running protoc: %v", err) 114 } 115 // Build our file map, and test for existance 116 files := map[string]*genFileInfo{} 117 byBase := map[string]*genFileInfo{} 118 for _, path := range expected { 119 info := &genFileInfo{ 120 path: path, 121 base: filepath.Base(path), 122 expected: true, 123 unique: true, 124 } 125 files[info.path] = info 126 if byBase[info.base] != nil { 127 info.unique = false 128 byBase[info.base].unique = false 129 } else { 130 byBase[info.base] = info 131 } 132 } 133 // Walk the generated files 134 filepath.Walk(tmpDir, func(path string, f os.FileInfo, err error) error { 135 relPath, err := filepath.Rel(tmpDir, path) 136 if err != nil { 137 return err 138 } 139 if relPath == "." { 140 return nil 141 } 142 143 if f.IsDir() { 144 if err := os.Mkdir(filepath.Join(absOutPath, relPath), f.Mode()); !os.IsExist(err) { 145 return err 146 } 147 return nil 148 } 149 150 if !strings.HasSuffix(path, ".go") { 151 return nil 152 } 153 154 info := &genFileInfo{ 155 path: path, 156 base: filepath.Base(path), 157 created: true, 158 } 159 160 if foundInfo, ok := files[relPath]; ok { 161 foundInfo.created = true 162 foundInfo.from = info 163 return nil 164 } 165 files[relPath] = info 166 copyTo := byBase[info.base] 167 switch { 168 case copyTo == nil: 169 // Unwanted output 170 case !copyTo.unique: 171 // not unique, no copy allowed 172 case copyTo.from != nil: 173 copyTo.ambiguious = true 174 info.ambiguious = true 175 default: 176 copyTo.from = info 177 copyTo.created = true 178 info.expected = true 179 } 180 return nil 181 }) 182 buf := &bytes.Buffer{} 183 for _, f := range files { 184 switch { 185 case f.expected && !f.created: 186 // Some plugins only create output files if the proto source files have 187 // have relevant definitions (e.g., services for grpc_gateway). Create 188 // trivial files that the compiler will ignore for missing outputs. 189 data := []byte("// +build ignore\n\npackage ignore") 190 if err := ioutil.WriteFile(abs(f.path), data, 0644); err != nil { 191 return err 192 } 193 case f.expected && f.ambiguious: 194 fmt.Fprintf(buf, "Ambiguious output %v.\n", f.path) 195 case f.from != nil: 196 data, err := ioutil.ReadFile(f.from.path) 197 if err != nil { 198 return err 199 } 200 if err := ioutil.WriteFile(abs(f.path), data, 0644); err != nil { 201 return err 202 } 203 case !f.expected: 204 //fmt.Fprintf(buf, "Unexpected output %v.\n", f.path) 205 } 206 if buf.Len() > 0 { 207 fmt.Fprintf(buf, "Check that the go_package option is %q.", *importpath) 208 return errors.New(buf.String()) 209 } 210 } 211 212 return nil 213} 214 215func main() { 216 if err := run(os.Args[1:]); err != nil { 217 log.Fatal(err) 218 } 219} 220