Implement -check flag

When this flag is used:

* The program never modifies any files
* If all files in the pattern contain a license, the program exits with
a zero exit code
* If at least one file in the pattern requires modification to include
license text, the program prints such files to STDOUT and exits with
a non-zero exit code
This commit is contained in:
Mithun Ayachit
2020-02-12 06:30:58 -06:00
parent c46413539e
commit 27146d5f03
3 changed files with 118 additions and 32 deletions

View File

@@ -18,6 +18,7 @@ to any file that already has one.
-f custom license file (no default) -f custom license file (no default)
-l license type: apache, bsd, mit (defaults to "apache") -l license type: apache, bsd, mit (defaults to "apache")
-y year (defaults to current year) -y year (defaults to current year)
-check check only mode: verify presence of license headers and exit with non-zero code if missing
The pattern argument can be provided multiple times, and may also refer The pattern argument can be provided multiple times, and may also refer
to single files. to single files.

104
main.go
View File

@@ -18,6 +18,7 @@ package main
import ( import (
"bytes" "bytes"
"errors"
"flag" "flag"
"fmt" "fmt"
"html/template" "html/template"
@@ -46,11 +47,12 @@ Flags:
` `
var ( var (
holder = flag.String("c", "Google LLC", "copyright holder") holder = flag.String("c", "Google LLC", "copyright holder")
license = flag.String("l", "apache", "license type: apache, bsd, mit") license = flag.String("l", "apache", "license type: apache, bsd, mit")
licensef = flag.String("f", "", "license file") licensef = flag.String("f", "", "license file")
year = flag.String("y", fmt.Sprint(time.Now().Year()), "copyright year(s)") year = flag.String("y", fmt.Sprint(time.Now().Year()), "copyright year(s)")
verbose = flag.Bool("v", false, "verbose mode: print the name of the files that are modified") verbose = flag.Bool("v", false, "verbose mode: print the name of the files that are modified")
checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing")
) )
func main() { func main() {
@@ -97,13 +99,35 @@ func main() {
for f := range ch { for f := range ch {
f := f // https://golang.org/doc/faq#closures_and_goroutines f := f // https://golang.org/doc/faq#closures_and_goroutines
wg.Go(func() error { wg.Go(func() error {
modified, err := addLicense(f.path, f.mode, t, data) if *checkonly {
if err != nil { // Check if file extension is known
log.Printf("%s: %v", f.path, err) lic, err := licenseHeader(f.path, t, data)
return err if err != nil {
} log.Printf("%s: %v", f.path, err)
if *verbose && modified { return err
log.Printf("%s modified", f.path) }
if lic == nil { // Unknown fileExtension
return nil
}
// Check if file has a license
isMissingLicenseHeader, err := fileHasLicense(f.path)
if err != nil {
log.Printf("%s: %v", f.path, err)
return err
}
if isMissingLicenseHeader {
fmt.Printf("%s\n", f.path)
return errors.New("missing license header")
}
} else {
modified, err := addLicense(f.path, f.mode, t, data)
if err != nil {
log.Printf("%s: %v", f.path, err)
return err
}
if *verbose && modified {
log.Printf("%s modified", f.path)
}
} }
return nil return nil
}) })
@@ -142,11 +166,45 @@ func walk(ch chan<- *file, start string) {
} }
func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) { func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) {
var lic []byte
var err error
lic, err = licenseHeader(path, tmpl, data)
if err != nil || lic == nil {
return false, err
}
b, err := ioutil.ReadFile(path)
if err != nil || hasLicense(b) {
return false, err
}
line := hashBang(b)
if len(line) > 0 {
b = b[len(line):]
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
lic = append(line, lic...)
}
b = append(lic, b...)
return true, ioutil.WriteFile(path, b, fmode)
}
// fileHasLicense reports whether the file at path contains a license header.
func fileHasLicense(path string) (bool, error) {
b, err := ioutil.ReadFile(path)
if err != nil || hasLicense(b) {
return false, err
}
return true, nil
}
func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) {
var lic []byte var lic []byte
var err error var err error
switch fileExtension(path) { switch fileExtension(path) {
default: default:
return false, nil return nil, nil
case ".c", ".h": case ".c", ".h":
lic, err = prefix(tmpl, data, "/*", " * ", " */") lic, err = prefix(tmpl, data, "/*", " * ", " */")
case ".js", ".jsx", ".tsx", ".css", ".tf", ".ts": case ".js", ".jsx", ".tsx", ".css", ".tf", ".ts":
@@ -168,25 +226,7 @@ func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *c
case ".ml", ".mli", ".mll", ".mly": case ".ml", ".mli", ".mll", ".mly":
lic, err = prefix(tmpl, data, "(**", " ", "*)") lic, err = prefix(tmpl, data, "(**", " ", "*)")
} }
if err != nil || lic == nil { return lic, err
return false, err
}
b, err := ioutil.ReadFile(path)
if err != nil || hasLicense(b) {
return false, err
}
line := hashBang(b)
if len(line) > 0 {
b = b[len(line):]
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
lic = append(line, lic...)
}
b = append(lic, b...)
return true, ioutil.WriteFile(path, b, fmode)
} }
func fileExtension(name string) string { func fileExtension(name string) string {

View File

@@ -139,3 +139,48 @@ func TestReadErrors(t *testing.T) {
} }
run(t, "chmod", "0644", samplefile) run(t, "chmod", "0644", samplefile)
} }
func TestCheckSuccess(t *testing.T) {
if os.Getenv("RUNME") != "" {
main()
return
}
tmp := tempDir(t)
t.Logf("tmp dir: %s", tmp)
samplefile := filepath.Join(tmp, "file.c")
run(t, "cp", "testdata/expected/file.c", samplefile)
cmd := exec.Command(os.Args[0],
"-test.run=TestCheckSuccess",
"-l", "apache", "-c", "Google LLC", "-y", "2018",
"-check", samplefile,
)
cmd.Env = []string{"RUNME=1"}
if out, err := cmd.CombinedOutput(); err != nil {
t.Fatalf("%v\n%s", err, out)
}
}
func TestCheckFail(t *testing.T) {
if os.Getenv("RUNME") != "" {
main()
return
}
tmp := tempDir(t)
t.Logf("tmp dir: %s", tmp)
samplefile := filepath.Join(tmp, "file.c")
run(t, "cp", "testdata/initial/file.c", samplefile)
cmd := exec.Command(os.Args[0],
"-test.run=TestCheckFail",
"-l", "apache", "-c", "Google LLC", "-y", "2018",
"-check", samplefile,
)
cmd.Env = []string{"RUNME=1"}
out, err := cmd.CombinedOutput()
if err == nil {
t.Fatalf("TestCheckFail exited with a zero exit code.\n%s", out)
}
}