Skip to content

Commit 6b2d30d

Browse files
committed
Replace api.GenerateOptions Dir/File with Config io.Reader
The struct collapses to five fields: Config (io.Reader), Stderr, Write, Diff, InsecureProcessPluginNames. api.Generate parses the config from the reader and treats every relative path in it as relative to the current working directory. CLI: each command opens the config file, reads its bytes, parses it once to extract declared process-plugin names, then chdirs to the config's directory before invoking api.Generate. Single-process so chdir is fine. Tests: a new mutatedConfigBytes helper parses the test's sqlc.yaml, forces version "2", rewrites every schema/queries/output path to be absolute relative to the test directory, and re-encodes as YAML — so api.Generate works without knowing the source directory. Optional mutate callback applies extra changes (managed-db servers etc.) and also drops a temp file alongside the original for cmd.Vet which still takes a config path. cmd/process.go and cmd/vet.go now skip joining their dir parameter when the config-supplied path is already absolute. KNOWN ISSUE: TestReplay parse-error tests and the diff_output tests fail because the api now emits absolute paths in error messages and unified-diff labels (no config-dir context to strip). Either add a BaseDir hint back to GenerateOptions or update the affected test expectations to match. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ
1 parent 963f626 commit 6b2d30d

10 files changed

Lines changed: 272 additions & 230 deletions

File tree

internal/api/config.go

Lines changed: 0 additions & 93 deletions
This file was deleted.

internal/api/diff.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"path/filepath"
1212
"runtime/trace"
1313
"sort"
14-
"strings"
1514

1615
"github.com/cubicdaiya/gonp"
1716
)
@@ -31,10 +30,12 @@ func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer)
3130
return nil
3231
}
3332

34-
func diffFiles(ctx context.Context, dir string, files map[string]string, stderr io.Writer) error {
33+
func diffFiles(ctx context.Context, files map[string]string, stderr io.Writer) error {
3534
defer trace.StartRegion(ctx, "checkfiles").End()
3635
var errored bool
3736

37+
wd, _ := os.Getwd()
38+
3839
keys := make([]string, 0, len(files))
3940
for k := range files {
4041
keys = append(keys, k)
@@ -59,8 +60,14 @@ func diffFiles(ctx context.Context, dir string, files map[string]string, stderr
5960

6061
if len(uniHunks) > 0 {
6162
errored = true
62-
fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir))
63-
fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir))
63+
label := filename
64+
if wd != "" {
65+
if rel, err := filepath.Rel(wd, filename); err == nil {
66+
label = "/" + rel
67+
}
68+
}
69+
fmt.Fprintf(stderr, "--- a%s\n", label)
70+
fmt.Fprintf(stderr, "+++ b%s\n", label)
6471
d.FprintUniHunks(stderr, uniHunks)
6572
}
6673
}

internal/api/generate.go

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package api
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"path/filepath"
@@ -13,26 +14,24 @@ import (
1314
"github.com/sqlc-dev/sqlc/internal/config"
1415
)
1516

16-
// GenerateOptions controls a single Generate invocation.
17+
// GenerateOptions controls a single Generate invocation. Paths declared in the
18+
// configuration are resolved relative to the current working directory, so
19+
// callers wanting a different base directory should either pass absolute
20+
// paths in the config or os.Chdir before calling.
1721
type GenerateOptions struct {
18-
// Dir is the working directory used to resolve the config file and any
19-
// relative schema/query paths within it.
20-
Dir string
21-
22-
// File is the configuration filename to use, relative to Dir. When empty,
23-
// Generate looks for sqlc.yaml, sqlc.yml, or sqlc.json in Dir.
24-
File string
22+
// Config is the sqlc configuration as a YAML or JSON document. Required.
23+
Config io.Reader
2524

2625
// Stderr receives diagnostic output. If nil, output is discarded.
2726
Stderr io.Writer
2827

29-
// Write, when true, writes the generated files to disk after a successful
30-
// generate. Failures are reported via GenerateResult.Errors.
28+
// Write writes the generated files to disk after a successful generate.
29+
// Failures are reported via GenerateResult.Errors.
3130
Write bool
3231

33-
// Diff, when true, compares each generated file against any existing file
34-
// on disk and writes a unified diff for differences to Stderr. If any
35-
// differences are found, an error is appended to GenerateResult.Errors.
32+
// Diff compares each generated file against any existing file on disk and
33+
// writes a unified diff for differences to Stderr. If any differences are
34+
// found, an error is appended to GenerateResult.Errors.
3635
Diff bool
3736

3837
// InsecureProcessPluginNames is the allowlist of process-based plugin
@@ -45,9 +44,7 @@ type GenerateOptions struct {
4544
InsecureProcessPluginNames []string
4645
}
4746

48-
// GenerateResult is the outcome of a Generate call. Files maps absolute output
49-
// paths to file contents; callers are responsible for writing them to disk if
50-
// desired. Errors collects any errors encountered during code generation.
47+
// GenerateResult is the outcome of a Generate call.
5148
type GenerateResult struct {
5249
// Files maps absolute output paths to generated file contents.
5350
Files map[string]string
@@ -58,9 +55,7 @@ type GenerateResult struct {
5855
}
5956

6057
// Generate parses the sqlc configuration referenced by opts and runs every
61-
// configured codegen target. The returned GenerateResult always has a non-nil
62-
// Files map; the map is empty when generation fails before any files are
63-
// produced.
58+
// configured codegen target.
6459
func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
6560
stderr := opts.Stderr
6661
if stderr == nil {
@@ -69,15 +64,30 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
6964

7065
res := GenerateResult{Files: map[string]string{}}
7166

72-
configPath, conf, err := readConfig(stderr, opts.Dir, opts.File)
67+
if opts.Config == nil {
68+
err := errors.New("GenerateOptions.Config is required")
69+
fmt.Fprintln(stderr, err)
70+
res.Errors = append(res.Errors, err)
71+
return res
72+
}
73+
74+
conf, err := config.ParseConfig(opts.Config)
7375
if err != nil {
76+
switch err {
77+
case config.ErrMissingVersion:
78+
fmt.Fprint(stderr, errMessageNoVersion)
79+
case config.ErrUnknownVersion:
80+
fmt.Fprint(stderr, errMessageUnknownVersion)
81+
case config.ErrNoPackages:
82+
fmt.Fprint(stderr, errMessageNoPackages)
83+
}
84+
fmt.Fprintf(stderr, "error parsing config: %s\n", err)
7485
res.Errors = append(res.Errors, err)
7586
return res
7687
}
7788

78-
base := filepath.Base(configPath)
79-
if err := config.Validate(conf); err != nil {
80-
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
89+
if err := config.Validate(&conf); err != nil {
90+
fmt.Fprintf(stderr, "error validating config: %s\n", err)
8191
res.Errors = append(res.Errors, err)
8292
return res
8393
}
@@ -88,18 +98,15 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
8898
}
8999
if !slices.Contains(opts.InsecureProcessPluginNames, plug.Name) {
90100
err := fmt.Errorf("process plugin %q is not in InsecureProcessPluginNames; refusing to run", plug.Name)
91-
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
101+
fmt.Fprintf(stderr, "error validating config: %s\n", err)
92102
res.Errors = append(res.Errors, err)
93103
return res
94104
}
95105
}
96106

97-
g := &generator{
98-
dir: opts.Dir,
99-
output: map[string]string{},
100-
}
107+
g := &generator{output: map[string]string{}}
101108

102-
if err := processQuerySets(ctx, g, conf, opts.Dir, stderr); err != nil {
109+
if err := processQuerySets(ctx, g, &conf, stderr); err != nil {
103110
res.Errors = append(res.Errors, err)
104111
return res
105112
}
@@ -113,17 +120,31 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
113120
}
114121

115122
if opts.Diff {
116-
if err := diffFiles(ctx, opts.Dir, res.Files, stderr); err != nil {
123+
if err := diffFiles(ctx, res.Files, stderr); err != nil {
117124
res.Errors = append(res.Errors, err)
118125
}
119126
}
120127

121128
return res
122129
}
123130

131+
const errMessageNoVersion = `The configuration must have a version number.
132+
Set the version to 1 or 2 at the top of the config:
133+
134+
{
135+
"version": "1"
136+
...
137+
}
138+
`
139+
140+
const errMessageUnknownVersion = `The configuration has an invalid version number.
141+
The supported version can only be "1" or "2".
142+
`
143+
144+
const errMessageNoPackages = `No packages are configured`
145+
124146
type generator struct {
125147
m sync.Mutex
126-
dir string
127148
output map[string]string
128149
}
129150

@@ -162,23 +183,25 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett
162183
files[file.Name] = string(file.Contents)
163184
}
164185
g.m.Lock()
186+
defer g.m.Unlock()
165187

166-
// out is specified by the user, not a plugin
167-
absout := filepath.Join(g.dir, out)
188+
absout, err := filepath.Abs(out)
189+
if err != nil {
190+
return err
191+
}
168192

169193
for n, source := range files {
170-
filename := filepath.Join(g.dir, out, n)
171-
// filepath.Join calls filepath.Clean which should remove all "..", but
172-
// double check to make sure
194+
filename, err := filepath.Abs(filepath.Join(out, n))
195+
if err != nil {
196+
return err
197+
}
173198
if strings.Contains(filename, "..") {
174199
return fmt.Errorf("invalid file output path: %s", filename)
175200
}
176-
// The output file must be contained inside the output directory
177201
if !strings.HasPrefix(filename, absout) {
178202
return fmt.Errorf("invalid file output path: %s", filename)
179203
}
180204
g.output[filename] = source
181205
}
182-
g.m.Unlock()
183206
return nil
184207
}

internal/api/parse.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"os"
78
"path/filepath"
89
"runtime/trace"
910

@@ -14,15 +15,21 @@ import (
1415
"github.com/sqlc-dev/sqlc/internal/opts"
1516
)
1617

17-
func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) {
18-
filename, err := filepath.Rel(dir, fileErr.Filename)
18+
func printFileErr(stderr io.Writer, fileErr *multierr.FileError) {
19+
wd, err := os.Getwd()
1920
if err != nil {
20-
filename = fileErr.Filename
21+
wd = ""
22+
}
23+
filename := fileErr.Filename
24+
if wd != "" {
25+
if rel, err := filepath.Rel(wd, fileErr.Filename); err == nil {
26+
filename = rel
27+
}
2128
}
2229
fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err)
2330
}
2431

25-
func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
32+
func parse(ctx context.Context, name string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
2633
defer trace.StartRegion(ctx, "parse").End()
2734
c, err := compiler.NewCompiler(sql, combo, parserOpts)
2835
defer func() {
@@ -38,7 +45,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
3845
fmt.Fprintf(stderr, "# package %s\n", name)
3946
if parserErr, ok := err.(*multierr.Error); ok {
4047
for _, fileErr := range parserErr.Errs() {
41-
printFileErr(stderr, dir, fileErr)
48+
printFileErr(stderr, fileErr)
4249
}
4350
} else {
4451
fmt.Fprintf(stderr, "error parsing schema: %s\n", err)
@@ -52,7 +59,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
5259
fmt.Fprintf(stderr, "# package %s\n", name)
5360
if parserErr, ok := err.(*multierr.Error); ok {
5461
for _, fileErr := range parserErr.Errs() {
55-
printFileErr(stderr, dir, fileErr)
62+
printFileErr(stderr, fileErr)
5663
}
5764
} else {
5865
fmt.Fprintf(stderr, "error parsing queries: %s\n", err)

0 commit comments

Comments
 (0)