diff --git a/main.go b/main.go index 839600e..90a0445 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "bytes" "flag" "fmt" + "html/template" "io/ioutil" "log" "os" @@ -44,9 +45,10 @@ Flags: ` var ( - holder = flag.String("c", "Google LLC", "copyright holder") - license = flag.String("l", "apache", "license type: apache, bsd, mit") - year = flag.Int("y", time.Now().Year(), "year") + holder = flag.String("c", "Google LLC", "copyright holder") + license = flag.String("l", "apache", "license type: apache, bsd, mit") + licensef = flag.String("f", "", "license file") + year = flag.Int("y", time.Now().Year(), "year") ) func main() { @@ -65,6 +67,26 @@ func main() { Holder: *holder, } + var t *template.Template + if *licensef != "" { + d, err := ioutil.ReadFile(*licensef) + if err != nil { + log.Printf("license file: %v", err) + os.Exit(1) + } + t, err = template.New("").Parse(string(d)) + if err != nil { + log.Printf("license file: %v", err) + os.Exit(1) + } + } else { + t = licenseTemplate[*license] + if t == nil { + log.Printf("unknown license: %s", *license) + os.Exit(1) + } + } + // process at most 1000 files in parallel ch := make(chan *file, 1000) done := make(chan struct{}) @@ -73,7 +95,7 @@ func main() { for f := range ch { wg.Add(1) go func(f *file) { - err := addLicense(f.path, f.mode, *license, data) + err := addLicense(f.path, f.mode, t, data) if err != nil { log.Printf("%s: %v", f.path, err) } @@ -110,30 +132,30 @@ func walk(ch chan<- *file, start string) { }) } -func addLicense(path string, fmode os.FileMode, typ string, data *copyrightData) error { +func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) error { var lic []byte var err error switch filepath.Ext(path) { default: return nil case ".c", ".h": - lic, err = prefix(typ, data, "/*", " * ", " */") + lic, err = prefix(tmpl, data, "/*", " * ", " */") case ".js", ".jsx", ".tsx", ".css", ".tf": - lic, err = prefix(typ, data, "/**", " * ", " */") + lic, err = prefix(tmpl, data, "/**", " * ", " */") case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart": - lic, err = prefix(typ, data, "", "// ", "") + lic, err = prefix(tmpl, data, "", "// ", "") case ".py", ".sh", ".yaml", ".yml": - lic, err = prefix(typ, data, "", "# ", "") + lic, err = prefix(tmpl, data, "", "# ", "") case ".el", ".lisp": - lic, err = prefix(typ, data, "", ";; ", "") + lic, err = prefix(tmpl, data, "", ";; ", "") case ".erl": - lic, err = prefix(typ, data, "", "% ", "") + lic, err = prefix(tmpl, data, "", "% ", "") case ".hs", ".sql": - lic, err = prefix(typ, data, "", "-- ", "") + lic, err = prefix(tmpl, data, "", "-- ", "") case ".html", ".xml": - lic, err = prefix(typ, data, "") + lic, err = prefix(tmpl, data, "") case ".php": - lic, err = prefix(typ, data, "") + lic, err = prefix(tmpl, data, "") } if err != nil || lic == nil { return err diff --git a/tmpl.go b/tmpl.go index 6119c1c..6ac2fbc 100644 --- a/tmpl.go +++ b/tmpl.go @@ -36,14 +36,9 @@ type copyrightData struct { Holder string } -// prefix will execute a license template l with data d +// prefix will execute a license template t with data d // and prefix the result with top, middle and bottom. -func prefix(l string, d *copyrightData, top, mid, bot string) ([]byte, error) { - t := licenseTemplate[l] - if t == nil { - return nil, fmt.Errorf("unknown license: %s", l) - } - +func prefix(t *template.Template, d *copyrightData, top, mid, bot string) ([]byte, error) { var buf bytes.Buffer if err := t.Execute(&buf, d); err != nil { return nil, err