diff --git a/main.go b/main.go index 90a0445..352d497 100644 --- a/main.go +++ b/main.go @@ -135,16 +135,16 @@ func walk(ch chan<- *file, start string) { func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) error { var lic []byte var err error - switch filepath.Ext(path) { + switch fileExtension(path) { default: return nil case ".c", ".h": lic, err = prefix(tmpl, data, "/*", " * ", " */") case ".js", ".jsx", ".tsx", ".css", ".tf": lic, err = prefix(tmpl, data, "/**", " * ", " */") - case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart": + case ".cc", ".cpp", ".cs", ".go", ".hh", ".hpp", ".java", ".m", ".mm", ".proto", ".rs", ".scala", ".swift", ".dart", ".groovy": lic, err = prefix(tmpl, data, "", "// ", "") - case ".py", ".sh", ".yaml", ".yml": + case ".py", ".sh", ".yaml", ".yml", ".dockerfile", "dockerfile", ".rb", "gemfile": lic, err = prefix(tmpl, data, "", "# ", "") case ".el", ".lisp": lic, err = prefix(tmpl, data, "", ";; ", "") @@ -178,10 +178,19 @@ func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *c return ioutil.WriteFile(path, b, fmode) } +func fileExtension(name string) string { + if v := filepath.Ext(name); v != "" { + return strings.ToLower(v) + } + return strings.ToLower(filepath.Base(name)) +} + var head = []string{ - "#!", // shell script - "