diff --git a/main.go b/main.go index 16310ba..10be4ec 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,7 @@ Flags: var ( skipExtensionFlags skipExtensionFlag + spdx spdxFlag holder = flag.String("c", "Google LLC", "copyright holder") license = flag.String("l", "apache", "license type: apache, bsd, mit, mpl") @@ -58,6 +59,15 @@ var ( checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing") ) +func init() { + flag.Usage = func() { + fmt.Fprintln(os.Stderr, helpText) + flag.PrintDefaults() + } + flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go") + flag.Var(&spdx, "s", "Include SPDX identifier in license header. Set -s=only to only include SPDX identifier.") +} + type skipExtensionFlag []string func (i *skipExtensionFlag) String() string { @@ -69,41 +79,54 @@ func (i *skipExtensionFlag) Set(value string) error { return nil } -func main() { - flag.Usage = func() { - fmt.Fprintln(os.Stderr, helpText) - flag.PrintDefaults() +// spdxFlag defines the line flag behavior for specifying SPDX support. +type spdxFlag string + +const ( + spdxOff spdxFlag = "" + spdxOn spdxFlag = "true" // value set by flag package on bool flag + spdxOnly spdxFlag = "only" +) + +// IsBoolFlag causes a bare '-s' flag to be set as the string 'true'. This +// allows the use of the bare '-s' or setting a string '-s=only'. +func (i *spdxFlag) IsBoolFlag() bool { return true } +func (i *spdxFlag) String() string { return string(*i) } + +func (i *spdxFlag) Set(value string) error { + v := spdxFlag(value) + if v != spdxOn && v != spdxOnly { + return fmt.Errorf("error: flag 's' expects '%v' or '%v'", spdxOn, spdxOnly) } - flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go") + *i = v + return nil +} + +func main() { flag.Parse() if flag.NArg() == 0 { flag.Usage() os.Exit(1) } - data := ©rightData{ - Year: *year, - Holder: *holder, + // map legacy license values + if t, ok := legacyLicenseTypes[*license]; ok { + *license = t } - var t *template.Template - if *licensef != "" { - d, err := ioutil.ReadFile(*licensef) - if err != nil { - log.Printf("license file: %v", err) - os.Exit(1) - } - t, err = template.New("").Parse(string(d)) - if err != nil { - log.Printf("license file: %v", err) - os.Exit(1) - } - } else { - t = licenseTemplate[*license] - if t == nil { - log.Printf("unknown license: %s", *license) - os.Exit(1) - } + data := licenseData{ + Year: *year, + Holder: *holder, + SPDXID: *license, + } + + tpl, err := fetchTemplate(*license, *licensef, spdx) + if err != nil { + log.Fatal(err) + } + t, err := template.New("").Parse(tpl) + if err != nil { + log.Fatal(err) } // process at most 1000 files in parallel @@ -189,7 +212,7 @@ func walk(ch chan<- *file, start string) { // addLicense add a license to the file if missing. // // It returns true if the file was updated. -func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) { +func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data licenseData) (bool, error) { var lic []byte var err error lic, err = licenseHeader(path, tmpl, data) @@ -227,32 +250,32 @@ func fileHasLicense(path string) (bool, error) { return hasLicense(b) || isGenerated(b), nil } -func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) { +func licenseHeader(path string, tmpl *template.Template, data licenseData) ([]byte, error) { var lic []byte var err error switch fileExtension(path) { default: return nil, nil case ".c", ".h", ".gv": - lic, err = prefix(tmpl, data, "/*", " * ", " */") + lic, err = executeTemplate(tmpl, data, "/*", " * ", " */") case ".js", ".mjs", ".cjs", ".jsx", ".tsx", ".css", ".scss", ".sass", ".tf", ".ts": - lic, err = prefix(tmpl, data, "/**", " * ", " */") + lic, err = executeTemplate(tmpl, data, "/**", " * ", " */") case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart", ".groovy", ".kt", ".kts", ".v", ".sv": - lic, err = prefix(tmpl, data, "", "// ", "") + lic, err = executeTemplate(tmpl, data, "", "// ", "") case ".py", ".sh", ".yaml", ".yml", ".dockerfile", "dockerfile", ".rb", "gemfile", ".tcl", ".bzl": - lic, err = prefix(tmpl, data, "", "# ", "") + lic, err = executeTemplate(tmpl, data, "", "# ", "") case ".el", ".lisp": - lic, err = prefix(tmpl, data, "", ";; ", "") + lic, err = executeTemplate(tmpl, data, "", ";; ", "") case ".erl": - lic, err = prefix(tmpl, data, "", "% ", "") + lic, err = executeTemplate(tmpl, data, "", "% ", "") case ".hs", ".sql", ".sdl": - lic, err = prefix(tmpl, data, "", "-- ", "") + lic, err = executeTemplate(tmpl, data, "", "-- ", "") case ".html", ".xml", ".vue": - lic, err = prefix(tmpl, data, "") + lic, err = executeTemplate(tmpl, data, "") case ".php": - lic, err = prefix(tmpl, data, "", "// ", "") + lic, err = executeTemplate(tmpl, data, "", "// ", "") case ".ml", ".mli", ".mll", ".mly": - lic, err = prefix(tmpl, data, "(**", " ", "*)") + lic, err = executeTemplate(tmpl, data, "(**", " ", "*)") } return lic, err } @@ -308,5 +331,6 @@ func hasLicense(b []byte) bool { n = len(b) } return bytes.Contains(bytes.ToLower(b[:n]), []byte("copyright")) || - bytes.Contains(bytes.ToLower(b[:n]), []byte("mozilla public")) + bytes.Contains(bytes.ToLower(b[:n]), []byte("mozilla public")) || + bytes.Contains(bytes.ToLower(b[:n]), []byte("SPDX-License-Identifier")) } diff --git a/testdata/custom.tpl b/testdata/custom.tpl new file mode 100644 index 0000000..5b1bc46 --- /dev/null +++ b/testdata/custom.tpl @@ -0,0 +1,3 @@ +Copyright {{.Year}} {{.Holder}} + +Custom License Template diff --git a/tmpl.go b/tmpl.go index b8540eb..e12bcc8 100644 --- a/tmpl.go +++ b/tmpl.go @@ -19,27 +19,69 @@ import ( "bytes" "fmt" "html/template" + "io/ioutil" "strings" "unicode" ) -var licenseTemplate = make(map[string]*template.Template) - -func init() { - licenseTemplate["apache"] = template.Must(template.New("").Parse(tmplApache)) - licenseTemplate["mit"] = template.Must(template.New("").Parse(tmplMIT)) - licenseTemplate["bsd"] = template.Must(template.New("").Parse(tmplBSD)) - licenseTemplate["mpl"] = template.Must(template.New("").Parse(tmplMPL)) +var licenseTemplate = map[string]string{ + "Apache-2.0": tmplApache, + "MIT": tmplMIT, + "bsd": tmplBSD, + "MPL-2.0": tmplMPL, } -type copyrightData struct { - Year string - Holder string +// maintain backwards compatibility by mapping legacy license types to their +// SPDX equivalents. +var legacyLicenseTypes = map[string]string{ + "apache": "Apache-2.0", + "mit": "MIT", + "mpl": "MPL-2.0", } -// prefix will execute a license template t with data d +// licenseData specifies the data used to fill out a license template. +type licenseData struct { + Year string // Copyright year(s). + Holder string // Name of the copyright holder. + SPDXID string // SPDX Identifier +} + +// fetchTemplate returns the license template for the specified license and +// optional templateFile. If templateFile is provided, the license is read +// from the specified file. Otherwise, a template is loaded for the specified +// license, if recognized. +func fetchTemplate(license string, templateFile string, spdx spdxFlag) (string, error) { + var t string + if spdx == spdxOnly { + t = tmplSPDX + } else if templateFile != "" { + d, err := ioutil.ReadFile(templateFile) + if err != nil { + return "", fmt.Errorf("license file: %w", err) + } + + t = string(d) + } else { + t = licenseTemplate[license] + if t == "" { + if spdx == spdxOn { + // unknown license, but SPDX headers requested + t = tmplSPDX + } else { + return "", fmt.Errorf("unknown license: %q. Include the '-s' flag to request SPDX style headers using this license.", license) + } + } else if spdx == spdxOn { + // append spdx headers to recognized license + t = t + spdxSuffix + } + } + + return t, nil +} + +// executeTemplate will execute a license template t with data d // and prefix the result with top, middle and bottom. -func prefix(t *template.Template, d *copyrightData, top, mid, bot string) ([]byte, error) { +func executeTemplate(t *template.Template, d licenseData, top, mid, bot string) ([]byte, error) { var buf bytes.Buffer if err := t.Execute(&buf, d); err != nil { return nil, err @@ -99,3 +141,8 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.` const tmplMPL = `This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.` + +const tmplSPDX = `{{ if and .Year .Holder }}Copyright {{.Year}} {{.Holder}} +{{ end }}SPDX-License-Identifier: {{.SPDXID}}` + +const spdxSuffix = "\n\nSPDX-License-Identifier: {{.SPDXID}}" diff --git a/tmpl_test.go b/tmpl_test.go new file mode 100644 index 0000000..4ed972c --- /dev/null +++ b/tmpl_test.go @@ -0,0 +1,140 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "errors" + "html/template" + "os" + "testing" +) + +func init() { + // ensure that pre-defined templates must parse + template.Must(template.New("").Parse(tmplApache)) + template.Must(template.New("").Parse(tmplMIT)) + template.Must(template.New("").Parse(tmplBSD)) + template.Must(template.New("").Parse(tmplMPL)) +} + +func TestFetchTemplate(t *testing.T) { + tests := []struct { + description string // test case description + license string // license passed to fetchTemplate + templateFile string // templatefile passed to fetchTemplate + spdx spdxFlag // spdx value passed to fetchTemplate + wantTemplate string // expected returned template + wantErr error // expected returned error + }{ + // custom template files + { + "non-existant template file", + "", + "/does/not/exist", + spdxOff, + "", + os.ErrNotExist, + }, + { + "custom template file", + "", + "testdata/custom.tpl", + spdxOff, + "Copyright {{.Year}} {{.Holder}}\n\nCustom License Template\n", + nil, + }, + + { + "unknown license", + "unknown", + "", + spdxOff, + "", + errors.New(`unknown license: "unknown". Include the '-s' flag to request SPDX style headers using this license.`), + }, + + // pre-defined license templates, no SPDX + { + "apache license template", + "Apache-2.0", + "", + spdxOff, + tmplApache, + nil, + }, + { + "mit license template", + "MIT", + "", + spdxOff, + tmplMIT, + nil, + }, + { + "bsd license template", + "bsd", + "", + spdxOff, + tmplBSD, + nil, + }, + { + "mpl license template", + "MPL-2.0", + "", + spdxOff, + tmplMPL, + nil, + }, + + // SPDX variants + { + "apache license template with SPDX added", + "Apache-2.0", + "", + spdxOn, + tmplApache + spdxSuffix, + nil, + }, + { + "apache license template with SPDX only", + "Apache-2.0", + "", + spdxOnly, + tmplSPDX, + nil, + }, + { + "unknown license with SPDX only", + "unknown", + "", + spdxOnly, + tmplSPDX, + nil, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + tpl, err := fetchTemplate(tt.license, tt.templateFile, tt.spdx) + if tt.wantErr != nil && (err == nil || (!errors.Is(err, tt.wantErr) && err.Error() != tt.wantErr.Error())) { + t.Fatalf("fetchTemplate(%q, %q) returned error: %#v, want %#v", tt.license, tt.templateFile, err, tt.wantErr) + } + if tpl != tt.wantTemplate { + t.Errorf("fetchTemplate(%q, %q) returned template: %q, want %q", tt.license, tt.templateFile, tpl, tt.wantTemplate) + } + }) + } +}