diff --git a/main.go b/main.go index e9ad11a..0263907 100644 --- a/main.go +++ b/main.go @@ -277,12 +277,15 @@ func fileHasLicense(path string) (bool, error) { return hasLicense(b) || isGenerated(b), nil } +// licenseHeader populates the provided license template with data, and returns +// it with the proper prefix for the file type specified by path. The file does +// not need to actually exist, only its name is used to determine the prefix. func licenseHeader(path string, tmpl *template.Template, data licenseData) ([]byte, error) { var lic []byte var err error - switch fileExtension(path) { - default: - return nil, nil + base := strings.ToLower(filepath.Base(path)) + + switch fileExtension(base) { case ".c", ".h", ".gv", ".java", ".scala", ".kt", ".kts": lic, err = executeTemplate(tmpl, data, "/*", " * ", " */") case ".js", ".mjs", ".cjs", ".jsx", ".tsx", ".css", ".scss", ".sass", ".tf", ".ts": @@ -307,11 +310,13 @@ func licenseHeader(path string, tmpl *template.Template, data licenseData) ([]by return lic, err } +// fileExtension returns the file extension of name, or the full name if there +// is no extension. func fileExtension(name string) string { if v := filepath.Ext(name); v != "" { - return strings.ToLower(v) + return v } - return strings.ToLower(filepath.Base(name)) + return name } var head = []string{