diff --git a/main.go b/main.go index 16310ba..143b797 100644 --- a/main.go +++ b/main.go @@ -58,6 +58,14 @@ var ( checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing") ) +func init() { + flag.Usage = func() { + fmt.Fprintln(os.Stderr, helpText) + flag.PrintDefaults() + } + flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go") +} + type skipExtensionFlag []string func (i *skipExtensionFlag) String() string { @@ -70,40 +78,24 @@ func (i *skipExtensionFlag) Set(value string) error { } func main() { - flag.Usage = func() { - fmt.Fprintln(os.Stderr, helpText) - flag.PrintDefaults() - } - flag.Var(&skipExtensionFlags, "skip", "To skip files to check/add the header file, for example: -skip rb -skip go") flag.Parse() if flag.NArg() == 0 { flag.Usage() os.Exit(1) } - data := ©rightData{ + data := licenseData{ Year: *year, Holder: *holder, } - var t *template.Template - if *licensef != "" { - d, err := ioutil.ReadFile(*licensef) - if err != nil { - log.Printf("license file: %v", err) - os.Exit(1) - } - t, err = template.New("").Parse(string(d)) - if err != nil { - log.Printf("license file: %v", err) - os.Exit(1) - } - } else { - t = licenseTemplate[*license] - if t == nil { - log.Printf("unknown license: %s", *license) - os.Exit(1) - } + tpl, err := fetchTemplate(*license, *licensef) + if err != nil { + log.Fatal(err) + } + t, err := template.New("").Parse(tpl) + if err != nil { + log.Fatal(err) } // process at most 1000 files in parallel @@ -189,7 +181,7 @@ func walk(ch chan<- *file, start string) { // addLicense add a license to the file if missing. // // It returns true if the file was updated. -func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) { +func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data licenseData) (bool, error) { var lic []byte var err error lic, err = licenseHeader(path, tmpl, data) @@ -227,32 +219,32 @@ func fileHasLicense(path string) (bool, error) { return hasLicense(b) || isGenerated(b), nil } -func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) { +func licenseHeader(path string, tmpl *template.Template, data licenseData) ([]byte, error) { var lic []byte var err error switch fileExtension(path) { default: return nil, nil case ".c", ".h", ".gv": - lic, err = prefix(tmpl, data, "/*", " * ", " */") + lic, err = executeTemplate(tmpl, data, "/*", " * ", " */") case ".js", ".mjs", ".cjs", ".jsx", ".tsx", ".css", ".scss", ".sass", ".tf", ".ts": - lic, err = prefix(tmpl, data, "/**", " * ", " */") + lic, err = executeTemplate(tmpl, data, "/**", " * ", " */") case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart", ".groovy", ".kt", ".kts", ".v", ".sv": - lic, err = prefix(tmpl, data, "", "// ", "") + lic, err = executeTemplate(tmpl, data, "", "// ", "") case ".py", ".sh", ".yaml", ".yml", ".dockerfile", "dockerfile", ".rb", "gemfile", ".tcl", ".bzl": - lic, err = prefix(tmpl, data, "", "# ", "") + lic, err = executeTemplate(tmpl, data, "", "# ", "") case ".el", ".lisp": - lic, err = prefix(tmpl, data, "", ";; ", "") + lic, err = executeTemplate(tmpl, data, "", ";; ", "") case ".erl": - lic, err = prefix(tmpl, data, "", "% ", "") + lic, err = executeTemplate(tmpl, data, "", "% ", "") case ".hs", ".sql", ".sdl": - lic, err = prefix(tmpl, data, "", "-- ", "") + lic, err = executeTemplate(tmpl, data, "", "-- ", "") case ".html", ".xml", ".vue": - lic, err = prefix(tmpl, data, "") + lic, err = executeTemplate(tmpl, data, "") case ".php": - lic, err = prefix(tmpl, data, "", "// ", "") + lic, err = executeTemplate(tmpl, data, "", "// ", "") case ".ml", ".mli", ".mll", ".mly": - lic, err = prefix(tmpl, data, "(**", " ", "*)") + lic, err = executeTemplate(tmpl, data, "(**", " ", "*)") } return lic, err } diff --git a/testdata/custom.tpl b/testdata/custom.tpl new file mode 100644 index 0000000..5b1bc46 --- /dev/null +++ b/testdata/custom.tpl @@ -0,0 +1,3 @@ +Copyright {{.Year}} {{.Holder}} + +Custom License Template diff --git a/tmpl.go b/tmpl.go index b8540eb..47dda73 100644 --- a/tmpl.go +++ b/tmpl.go @@ -19,27 +19,50 @@ import ( "bytes" "fmt" "html/template" + "io/ioutil" "strings" "unicode" ) -var licenseTemplate = make(map[string]*template.Template) - -func init() { - licenseTemplate["apache"] = template.Must(template.New("").Parse(tmplApache)) - licenseTemplate["mit"] = template.Must(template.New("").Parse(tmplMIT)) - licenseTemplate["bsd"] = template.Must(template.New("").Parse(tmplBSD)) - licenseTemplate["mpl"] = template.Must(template.New("").Parse(tmplMPL)) +var licenseTemplate = map[string]string{ + "apache": tmplApache, + "mit": tmplMIT, + "bsd": tmplBSD, + "mpl": tmplMPL, } -type copyrightData struct { - Year string - Holder string +// licenseData specifies the data used to fill out a license template. +type licenseData struct { + Year string // Copyright year(s). + Holder string // Name of the copyright holder. } -// prefix will execute a license template t with data d +// fetchTemplate returns the license template for the specified license and +// optional templateFile. If templateFile is provided, the license is read +// from the specified file. Otherwise, a template is loaded for the specified +// license, if recognized. +func fetchTemplate(license string, templateFile string) (string, error) { + var t string + if templateFile != "" { + d, err := ioutil.ReadFile(templateFile) + if err != nil { + return "", fmt.Errorf("license file: %w", err) + } + + t = string(d) + } else { + t = licenseTemplate[license] + if t == "" { + return "", fmt.Errorf("unknown license: %q", license) + } + } + + return t, nil +} + +// executeTemplate will execute a license template t with data d // and prefix the result with top, middle and bottom. -func prefix(t *template.Template, d *copyrightData, top, mid, bot string) ([]byte, error) { +func executeTemplate(t *template.Template, d licenseData, top, mid, bot string) ([]byte, error) { var buf bytes.Buffer if err := t.Execute(&buf, d); err != nil { return nil, err diff --git a/tmpl_test.go b/tmpl_test.go new file mode 100644 index 0000000..be156fd --- /dev/null +++ b/tmpl_test.go @@ -0,0 +1,102 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "errors" + "html/template" + "os" + "testing" +) + +func init() { + // ensure that pre-defined templates must parse + template.Must(template.New("").Parse(tmplApache)) + template.Must(template.New("").Parse(tmplMIT)) + template.Must(template.New("").Parse(tmplBSD)) + template.Must(template.New("").Parse(tmplMPL)) +} + +func TestFetchTemplate(t *testing.T) { + tests := []struct { + description string // test case description + license string // license passed to fetchTemplate + templateFile string // templatefile passed to fetchTemplate + wantTemplate string // expected returned template + wantErr error // expected returned error + }{ + { + "non-existant template file", + "", + "/does/not/exist", + "", + os.ErrNotExist, + }, + { + "custom template file", + "", + "testdata/custom.tpl", + "Copyright {{.Year}} {{.Holder}}\n\nCustom License Template\n", + nil, + }, + { + "unknown license", + "unknown", + "", + "", + errors.New(`unknown license: "unknown"`), + }, + { + "apache license template", + "apache", + "", + tmplApache, + nil, + }, + { + "mit license template", + "mit", + "", + tmplMIT, + nil, + }, + { + "bsd license template", + "bsd", + "", + tmplBSD, + nil, + }, + { + "mpl license template", + "mpl", + "", + tmplMPL, + nil, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + tpl, err := fetchTemplate(tt.license, tt.templateFile) + if tt.wantErr != nil && (err == nil || (!errors.Is(err, tt.wantErr) && err.Error() != tt.wantErr.Error())) { + t.Fatalf("fetchTemplate(%q, %q) returned error: %#v, want %#v", tt.license, tt.templateFile, err, tt.wantErr) + } + if tpl != tt.wantTemplate { + t.Errorf("fetchTemplate(%q, %q) returned template: %q, want %q", tt.license, tt.templateFile, tpl, tt.wantTemplate) + } + }) + } +}