Skip to content

Commit 9848cd6

Browse files
committed
Add api.GenerateOptions.Write and api.GenerateOptions.Diff
The two new boolean options let api.Generate cover the writefiles loop and the diff comparison that previously lived in cmd. The compile command becomes Generate with neither flag set, generate maps to Write: true, and diff maps to Diff: true. While simplifying GenerateOptions: * Drop MutateConfig — tests now express config mutations by writing a temporary configuration file via writeMutatedConfig and pointing GenerateOptions.File at it. The mutated config is parsed (always to v2 shape), forced to version "2", and round-tripped via yaml. * Drop DisableProcessPlugins from the API surface; we will revisit how to express that constraint. * Add MarshalJSON/MarshalYAML to AnalyzerDatabase so the parsed Config round-trips through yaml.Marshal cleanly, which is what the new test helper relies on. cmd/diff.go is gone and cmd/generate.go is left with only the helpers (readConfig, parse, printFileErr) other cmd commands still use. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ
1 parent eda87b8 commit 9848cd6

8 files changed

Lines changed: 252 additions & 208 deletions

File tree

internal/api/diff.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package api
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"context"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"os"
11+
"path/filepath"
12+
"runtime/trace"
13+
"sort"
14+
"strings"
15+
16+
"github.com/cubicdaiya/gonp"
17+
)
18+
19+
func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) error {
20+
defer trace.StartRegion(ctx, "writefiles").End()
21+
for filename, source := range files {
22+
if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
23+
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
24+
return err
25+
}
26+
if err := os.WriteFile(filename, []byte(source), 0644); err != nil {
27+
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
28+
return err
29+
}
30+
}
31+
return nil
32+
}
33+
34+
func diffFiles(ctx context.Context, dir string, files map[string]string, stderr io.Writer) error {
35+
defer trace.StartRegion(ctx, "checkfiles").End()
36+
var errored bool
37+
38+
keys := make([]string, 0, len(files))
39+
for k := range files {
40+
keys = append(keys, k)
41+
}
42+
sort.Strings(keys)
43+
44+
for _, filename := range keys {
45+
source := files[filename]
46+
if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
47+
errored = true
48+
continue
49+
}
50+
existing, err := os.ReadFile(filename)
51+
if err != nil {
52+
errored = true
53+
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
54+
continue
55+
}
56+
d := gonp.New(getLines(existing), getLines([]byte(source)))
57+
d.Compose()
58+
uniHunks := filterHunks(d.UnifiedHunks())
59+
60+
if len(uniHunks) > 0 {
61+
errored = true
62+
fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir))
63+
fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir))
64+
d.FprintUniHunks(stderr, uniHunks)
65+
}
66+
}
67+
if errored {
68+
return errors.New("diff found")
69+
}
70+
return nil
71+
}
72+
73+
func getLines(f []byte) []string {
74+
fp := bytes.NewReader(f)
75+
scanner := bufio.NewScanner(fp)
76+
lines := make([]string, 0)
77+
for scanner.Scan() {
78+
lines = append(lines, scanner.Text())
79+
}
80+
return lines
81+
}
82+
83+
func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] {
84+
var out []gonp.UniHunk[T]
85+
for i, uniHunk := range uniHunks {
86+
var changed bool
87+
for _, e := range uniHunk.GetChanges() {
88+
switch e.GetType() {
89+
case gonp.SesDelete:
90+
changed = true
91+
case gonp.SesAdd:
92+
changed = true
93+
}
94+
}
95+
if changed {
96+
out = append(out, uniHunks[i])
97+
}
98+
}
99+
return out
100+
}

internal/api/generate.go

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package api
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"io"
87
"path/filepath"
@@ -13,10 +12,6 @@ import (
1312
"github.com/sqlc-dev/sqlc/internal/config"
1413
)
1514

16-
// errPluginProcessDisabled is returned when the configuration uses a process
17-
// plugin but the caller has disabled them via GenerateOptions.DisableProcessPlugins.
18-
var errPluginProcessDisabled = errors.New("plugin: process-based plugins disabled via SQLCDEBUG=processplugins=0")
19-
2015
// GenerateOptions controls a single Generate invocation.
2116
type GenerateOptions struct {
2217
// Dir is the working directory used to resolve the config file and any
@@ -30,14 +25,14 @@ type GenerateOptions struct {
3025
// Stderr receives diagnostic output. If nil, output is discarded.
3126
Stderr io.Writer
3227

33-
// DisableProcessPlugins, when true, causes Generate to fail if the
34-
// configuration uses a process-based plugin. The sqlc CLI sets this from
35-
// SQLCDEBUG=processplugins=0.
36-
DisableProcessPlugins bool
28+
// Write, when true, writes the generated files to disk after a successful
29+
// generate. Failures are reported via GenerateResult.Errors.
30+
Write bool
3731

38-
// MutateConfig is called after the configuration is parsed but before it is
39-
// validated. It is intended for tests.
40-
MutateConfig func(*config.Config)
32+
// Diff, when true, compares each generated file against any existing file
33+
// on disk and writes a unified diff for differences to Stderr. If any
34+
// differences are found, an error is appended to GenerateResult.Errors.
35+
Diff bool
4136
}
4237

4338
// GenerateResult is the outcome of a Generate call. Files maps absolute output
@@ -69,9 +64,6 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
6964
res.Errors = append(res.Errors, err)
7065
return res
7166
}
72-
if opts.MutateConfig != nil {
73-
opts.MutateConfig(conf)
74-
}
7567

7668
base := filepath.Base(configPath)
7769
if err := config.Validate(conf); err != nil {
@@ -80,14 +72,6 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
8072
return res
8173
}
8274

83-
if opts.DisableProcessPlugins {
84-
if err := validateProcessPluginsDisabled(conf); err != nil {
85-
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
86-
res.Errors = append(res.Errors, err)
87-
return res
88-
}
89-
}
90-
9175
g := &generator{
9276
dir: opts.Dir,
9377
output: map[string]string{},
@@ -99,16 +83,20 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
9983
}
10084

10185
res.Files = g.output
102-
return res
103-
}
10486

105-
func validateProcessPluginsDisabled(cfg *config.Config) error {
106-
for _, plugin := range cfg.Plugins {
107-
if plugin.Process != nil {
108-
return errPluginProcessDisabled
87+
if opts.Write {
88+
if err := writeFiles(ctx, res.Files, stderr); err != nil {
89+
res.Errors = append(res.Errors, err)
10990
}
11091
}
111-
return nil
92+
93+
if opts.Diff {
94+
if err := diffFiles(ctx, opts.Dir, res.Files, stderr); err != nil {
95+
res.Errors = append(res.Errors, err)
96+
}
97+
}
98+
99+
return res
112100
}
113101

114102
type generator struct {

internal/cmd/cmd.go

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package cmd
22

33
import (
4-
"bufio"
5-
"bytes"
64
"context"
75
"errors"
86
"fmt"
@@ -12,11 +10,11 @@ import (
1210
"path/filepath"
1311
"runtime/trace"
1412

15-
"github.com/cubicdaiya/gonp"
1613
"github.com/spf13/cobra"
1714
"github.com/spf13/pflag"
1815
"gopkg.in/yaml.v3"
1916

17+
"github.com/sqlc-dev/sqlc/internal/api"
2018
"github.com/sqlc-dev/sqlc/internal/config"
2119
"github.com/sqlc-dev/sqlc/internal/debug"
2220
"github.com/sqlc-dev/sqlc/internal/info"
@@ -191,21 +189,15 @@ var genCmd = &cobra.Command{
191189
defer trace.StartRegion(cmd.Context(), "generate").End()
192190
stderr := cmd.ErrOrStderr()
193191
dir, name := getConfigPath(stderr, cmd.Flag("file"))
194-
output, err := Generate(cmd.Context(), dir, name, &Options{
195-
Env: ParseEnv(cmd),
192+
res := api.Generate(cmd.Context(), api.GenerateOptions{
193+
Dir: dir,
194+
File: name,
196195
Stderr: stderr,
196+
Write: true,
197197
})
198-
if err != nil {
198+
if len(res.Errors) > 0 {
199199
os.Exit(1)
200200
}
201-
defer trace.StartRegion(cmd.Context(), "writefiles").End()
202-
for filename, source := range output {
203-
os.MkdirAll(filepath.Dir(filename), 0755)
204-
if err := os.WriteFile(filename, []byte(source), 0644); err != nil {
205-
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
206-
return err
207-
}
208-
}
209201
return nil
210202
},
211203
}
@@ -217,58 +209,32 @@ var checkCmd = &cobra.Command{
217209
defer trace.StartRegion(cmd.Context(), "compile").End()
218210
stderr := cmd.ErrOrStderr()
219211
dir, name := getConfigPath(stderr, cmd.Flag("file"))
220-
_, err := Generate(cmd.Context(), dir, name, &Options{
221-
Env: ParseEnv(cmd),
212+
res := api.Generate(cmd.Context(), api.GenerateOptions{
213+
Dir: dir,
214+
File: name,
222215
Stderr: stderr,
223216
})
224-
if err != nil {
217+
if len(res.Errors) > 0 {
225218
os.Exit(1)
226219
}
227220
return nil
228221
},
229222
}
230223

231-
func getLines(f []byte) []string {
232-
fp := bytes.NewReader(f)
233-
scanner := bufio.NewScanner(fp)
234-
lines := make([]string, 0)
235-
for scanner.Scan() {
236-
lines = append(lines, scanner.Text())
237-
}
238-
return lines
239-
}
240-
241-
func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] {
242-
var out []gonp.UniHunk[T]
243-
for i, uniHunk := range uniHunks {
244-
var changed bool
245-
for _, e := range uniHunk.GetChanges() {
246-
switch e.GetType() {
247-
case gonp.SesDelete:
248-
changed = true
249-
case gonp.SesAdd:
250-
changed = true
251-
}
252-
}
253-
if changed {
254-
out = append(out, uniHunks[i])
255-
}
256-
}
257-
return out
258-
}
259-
260224
var diffCmd = &cobra.Command{
261225
Use: "diff",
262226
Short: "Compare the generated files to the existing files",
263227
RunE: func(cmd *cobra.Command, args []string) error {
264228
defer trace.StartRegion(cmd.Context(), "diff").End()
265229
stderr := cmd.ErrOrStderr()
266230
dir, name := getConfigPath(stderr, cmd.Flag("file"))
267-
opts := &Options{
268-
Env: ParseEnv(cmd),
231+
res := api.Generate(cmd.Context(), api.GenerateOptions{
232+
Dir: dir,
233+
File: name,
269234
Stderr: stderr,
270-
}
271-
if err := Diff(cmd.Context(), dir, name, opts); err != nil {
235+
Diff: true,
236+
})
237+
if len(res.Errors) > 0 {
272238
os.Exit(1)
273239
}
274240
return nil

internal/cmd/diff.go

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)