Skip to content

Commit 9050e51

Browse files
committed
feat(tool): Custom scalars mapping support
1 parent f7db49e commit 9050e51

File tree

8 files changed

+428
-6
lines changed

8 files changed

+428
-6
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[![GoDoc](https://godoc.org/github.com/pablor21/gqlschemagen?status.svg)](https://godoc.org/github.com/pablor21/gqlschemagen)
2-
[![GitHub release](https://img.shields.io/github/release/pablor21/gqlschemagen.svg?v0.1.0)](https://img.shields.io/github/release/pablor21/gqlschemagen.svg?v0.1.0)
2+
[![GitHub release](https://img.shields.io/github/release/pablor21/gqlschemagen.svg)](https://github.com/pablor21/gqlschemagen/releases)
33
[![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://raw.githubusercontent.com/pablor21/gqlschemagen/master/LICENSE)
4+
[![Go Report Card](https://goreportcard.com/badge/github.com/pablor21/gqlschemagen)](https://goreportcard.com/report/github.com/pablor21/gqlschemagen)
45

56
# GQLSchemaGen
67

generator/config.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ type Config struct {
124124
// These types are always considered "in scope" and won't trigger out-of-scope warnings
125125
KnownScalars []string `yaml:"known_scalars"`
126126

127+
// Scalars maps GraphQL scalar names to Go package types
128+
// Example: Scalars["ID"] = ScalarMapping{Model: ["github.com/google/uuid.UUID"]}
129+
// This allows mapping Go types like uuid.UUID to GraphQL scalars like ID
130+
Scalars map[string]ScalarMapping `yaml:"scalars"`
131+
127132
// Auto-generation configuration
128133
AutoGenerate AutoGenerateConfig `yaml:"auto_generate"`
129134

@@ -200,6 +205,13 @@ type AutoGenerateConfig struct {
200205
SuppressGenericTypeWarnings bool `yaml:"suppress_generic_type_warnings"`
201206
}
202207

208+
// ScalarMapping maps Go types to GraphQL scalars
209+
type ScalarMapping struct {
210+
// Model is a list of Go package paths that map to this scalar
211+
// Example: ["github.com/google/uuid.UUID", "github.com/gofrs/uuid.UUID"]
212+
Model []string `yaml:"model"`
213+
}
214+
203215
// CLIConfig contains CLI-specific configuration
204216
type CLIConfig struct {
205217
Watcher WatcherConfig `yaml:"watcher"`
@@ -563,3 +575,57 @@ func (c *Config) Validate() error {
563575

564576
return nil
565577
}
578+
579+
// GetScalarForGoType returns the GraphQL scalar name for a given Go type package path
580+
// Returns empty string if no mapping exists
581+
func (c *Config) GetScalarForGoType(goTypePath string) string {
582+
if c.Scalars == nil {
583+
return ""
584+
}
585+
586+
for scalarName, mapping := range c.Scalars {
587+
for _, modelPath := range mapping.Model {
588+
if modelPath == goTypePath {
589+
return scalarName
590+
}
591+
}
592+
}
593+
594+
return ""
595+
}
596+
597+
// IsBuiltInScalar returns true if the scalar is a built-in GraphQL scalar
598+
func IsBuiltInScalar(scalarName string) bool {
599+
builtIns := map[string]bool{
600+
"Int": true,
601+
"Float": true,
602+
"String": true,
603+
"Boolean": true,
604+
"ID": true,
605+
}
606+
return builtIns[scalarName]
607+
}
608+
609+
// GetUsedCustomScalars returns all custom scalar names used in scalar mappings
610+
// Excludes built-in GraphQL scalars (Int, Float, String, Boolean, ID) and known_scalars
611+
func (c *Config) GetUsedCustomScalars() []string {
612+
if c.Scalars == nil {
613+
return nil
614+
}
615+
616+
// Build a set of known scalars for fast lookup
617+
knownScalarSet := make(map[string]bool)
618+
for _, scalar := range c.KnownScalars {
619+
knownScalarSet[scalar] = true
620+
}
621+
622+
customScalars := make([]string, 0)
623+
for scalarName := range c.Scalars {
624+
// Skip built-in scalars and known scalars (assumed to be defined elsewhere)
625+
if !IsBuiltInScalar(scalarName) && !knownScalarSet[scalarName] {
626+
customScalars = append(customScalars, scalarName)
627+
}
628+
}
629+
630+
return customScalars
631+
}

generator/generator.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,6 @@ func (g *Generator) generateSingleFile(orders []string) (map[string]string, erro
907907

908908
if g.Config.SkipExisting && FileExists(outFile) {
909909
slog.Info("Skipping existing file", "file", outFile)
910-
fmt.Println("skip", outFile)
911910
return map[string]string{}, nil
912911
}
913912

@@ -925,6 +924,15 @@ func (g *Generator) generateSingleFile(orders []string) (map[string]string, erro
925924
Namespace: "",
926925
}
927926

927+
// Generate custom scalar declarations first (from scalar mappings)
928+
customScalars := g.Config.GetUsedCustomScalars()
929+
if len(customScalars) > 0 {
930+
slog.Debug("Generating custom scalar declarations", "count", len(customScalars))
931+
for _, scalarName := range customScalars {
932+
fmt.Fprintf(buf, "scalar %s\n\n", scalarName)
933+
}
934+
}
935+
928936
// Generate enums first
929937
slog.Debug("Generating enums", "count", len(g.P.EnumNames))
930938
for _, enumName := range g.P.EnumNames {
@@ -1028,6 +1036,19 @@ func (g *Generator) generatePackageFiles(orders []string) (map[string]string, er
10281036
return packages[pkgPath]
10291037
}
10301038

1039+
// Generate custom scalar declarations that will be added to the first package file
1040+
var scalarDeclarations string
1041+
customScalars := g.Config.GetUsedCustomScalars()
1042+
if len(customScalars) > 0 {
1043+
slog.Debug("Generating custom scalar declarations", "count", len(customScalars))
1044+
scalarBuf := &strings.Builder{}
1045+
for _, scalarName := range customScalars {
1046+
fmt.Fprintf(scalarBuf, "scalar %s\n", scalarName)
1047+
}
1048+
scalarBuf.WriteString("\n")
1049+
scalarDeclarations = scalarBuf.String()
1050+
}
1051+
10311052
// Group enums by package
10321053
for _, enumName := range g.P.EnumNames {
10331054
enumType := g.P.EnumTypes[enumName]
@@ -1120,6 +1141,12 @@ func (g *Generator) generatePackageFiles(orders []string) (map[string]string, er
11201141
buf = &strings.Builder{}
11211142
fileContents[outFile] = buf
11221143
slog.Debug("Creating buffer for package file", "file", outFile, "package", pkgPath)
1144+
1145+
// Add scalar declarations to the first file created
1146+
if scalarDeclarations != "" {
1147+
buf.WriteString(scalarDeclarations)
1148+
scalarDeclarations = "" // Only add once
1149+
}
11231150
}
11241151

11251152
// Create generation context for this package file
@@ -1200,6 +1227,19 @@ func (g *Generator) generateMultipleFiles(orders []string) (map[string]string, e
12001227
// Collect all content in memory first (map of file path -> content)
12011228
fileContents := make(map[string]*strings.Builder)
12021229

1230+
// Generate custom scalar declarations in a dedicated file
1231+
customScalars := g.Config.GetUsedCustomScalars()
1232+
if len(customScalars) > 0 {
1233+
slog.Debug("Generating custom scalar declarations", "count", len(customScalars))
1234+
scalarFile := filepath.Join(g.Config.Output, "_scalars"+g.Config.OutputFileExtension)
1235+
scalarBuf := &strings.Builder{}
1236+
for _, scalarName := range customScalars {
1237+
fmt.Fprintf(scalarBuf, "scalar %s\n", scalarName)
1238+
}
1239+
scalarBuf.WriteString("\n")
1240+
fileContents[scalarFile] = scalarBuf
1241+
}
1242+
12031243
// Generate enums first
12041244
slog.Debug("Generating enum files", "count", len(g.P.EnumNames))
12051245
for _, enumName := range g.P.EnumNames {

generator/parser.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ type Parser struct {
6565
pkgCache map[string]*packages.Package // dir path -> package info
6666
// External types loaded on-demand (not from scanned packages)
6767
ExternalTypes map[string]bool // type name -> true if loaded on-demand from external package
68+
// Import paths for package aliases (e.g., "uuid" -> "github.com/google/uuid")
69+
// This is per-file and gets updated during parsing
70+
fileImports map[string]string // package alias/name -> import path
6871
}
6972

7073
// ScannedTypeInfo stores metadata about a scanned type
@@ -95,6 +98,7 @@ func NewParser() *Parser {
9598
ScannedTypes: make(map[string]*ScannedTypeInfo),
9699
pkgCache: make(map[string]*packages.Package),
97100
ExternalTypes: make(map[string]bool),
101+
fileImports: make(map[string]string),
98102
}
99103
}
100104

@@ -159,6 +163,9 @@ func (p *Parser) parseFile(path string) error {
159163
// Extract file-level namespace from comments after package declaration
160164
fileNamespace := extractFileNamespace(f)
161165

166+
// Parse import statements to track package aliases
167+
p.parseImports(f)
168+
162169
// First pass: collect type declarations (structs and potential enums)
163170
for _, decl := range f.Decls {
164171
genDecl, ok := decl.(*ast.GenDecl)
@@ -463,6 +470,39 @@ func (p *Parser) GetPackageImportPathFromFile(filePath string, pkgName string, m
463470
return pkg.PkgPath
464471
}
465472

473+
// parseImports extracts import statements and maps package aliases to import paths
474+
// Accumulates imports from all files to handle types used across different files
475+
func (p *Parser) parseImports(f *ast.File) {
476+
for _, importSpec := range f.Imports {
477+
if importSpec.Path == nil {
478+
continue
479+
}
480+
481+
// Get the import path (remove quotes)
482+
importPath := strings.Trim(importSpec.Path.Value, "\"")
483+
484+
// Determine the package alias
485+
var pkgAlias string
486+
if importSpec.Name != nil {
487+
// Explicit alias: import foo "github.com/bar/baz"
488+
pkgAlias = importSpec.Name.Name
489+
if pkgAlias == "_" || pkgAlias == "." {
490+
// Skip blank imports and dot imports
491+
continue
492+
}
493+
} else {
494+
// No alias, use the last segment of the import path
495+
// "github.com/google/uuid" -> "uuid"
496+
parts := strings.Split(importPath, "/")
497+
pkgAlias = parts[len(parts)-1]
498+
}
499+
500+
// Store the mapping (overwrites if same alias used in different files with different imports,
501+
// but that would be a package name collision anyway)
502+
p.fileImports[pkgAlias] = importPath
503+
}
504+
}
505+
466506
// hasGqlEnumDirective checks if a GenDecl has @gqlEnum or @GqlEnum directive in its doc comments
467507
func hasGqlEnumDirective(decl *ast.GenDecl) bool {
468508
if decl.Doc == nil {

generator/scalar_mapping_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package generator
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestScalarMapping(t *testing.T) {
8+
config := &Config{
9+
Scalars: map[string]ScalarMapping{
10+
"ID": {
11+
Model: []string{
12+
"github.com/google/uuid.UUID",
13+
"github.com/gofrs/uuid.UUID",
14+
},
15+
},
16+
"DateTime": {
17+
Model: []string{
18+
"time.Time",
19+
},
20+
},
21+
},
22+
}
23+
24+
tests := []struct {
25+
name string
26+
goTypePath string
27+
expectedName string
28+
}{
29+
{
30+
name: "Google UUID maps to ID",
31+
goTypePath: "github.com/google/uuid.UUID",
32+
expectedName: "ID",
33+
},
34+
{
35+
name: "Gofrs UUID maps to ID",
36+
goTypePath: "github.com/gofrs/uuid.UUID",
37+
expectedName: "ID",
38+
},
39+
{
40+
name: "time.Time maps to DateTime",
41+
goTypePath: "time.Time",
42+
expectedName: "DateTime",
43+
},
44+
{
45+
name: "Unmapped type returns empty",
46+
goTypePath: "github.com/example/CustomType",
47+
expectedName: "",
48+
},
49+
}
50+
51+
for _, tt := range tests {
52+
t.Run(tt.name, func(t *testing.T) {
53+
result := config.GetScalarForGoType(tt.goTypePath)
54+
if result != tt.expectedName {
55+
t.Errorf("GetScalarForGoType(%q) = %q, want %q", tt.goTypePath, result, tt.expectedName)
56+
}
57+
})
58+
}
59+
}
60+
61+
func TestIsBuiltInScalar(t *testing.T) {
62+
tests := []struct {
63+
name string
64+
scalarName string
65+
expected bool
66+
}{
67+
{"Int is built-in", "Int", true},
68+
{"Float is built-in", "Float", true},
69+
{"String is built-in", "String", true},
70+
{"Boolean is built-in", "Boolean", true},
71+
{"ID is built-in", "ID", true},
72+
{"DateTime is custom", "DateTime", false},
73+
{"UUID is custom", "UUID", false},
74+
{"JSON is custom", "JSON", false},
75+
}
76+
77+
for _, tt := range tests {
78+
t.Run(tt.name, func(t *testing.T) {
79+
result := IsBuiltInScalar(tt.scalarName)
80+
if result != tt.expected {
81+
t.Errorf("IsBuiltInScalar(%q) = %v, want %v", tt.scalarName, result, tt.expected)
82+
}
83+
})
84+
}
85+
}
86+
87+
func TestGetUsedCustomScalars(t *testing.T) {
88+
config := &Config{
89+
KnownScalars: []string{"DateTime", "Upload"},
90+
Scalars: map[string]ScalarMapping{
91+
"ID": {
92+
Model: []string{"github.com/google/uuid.UUID"},
93+
},
94+
"DateTime": {
95+
Model: []string{"time.Time"},
96+
},
97+
"Upload": {
98+
Model: []string{"github.com/99designs/gqlgen/graphql.Upload"},
99+
},
100+
"CustomScalar": {
101+
Model: []string{"github.com/example/pkg.CustomType"},
102+
},
103+
"Int": {
104+
Model: []string{"int64"},
105+
},
106+
},
107+
}
108+
109+
customScalars := config.GetUsedCustomScalars()
110+
111+
// Should only return CustomScalar
112+
// - ID and Int are built-in GraphQL scalars
113+
// - DateTime and Upload are in known_scalars
114+
expectedCount := 1 // Only CustomScalar
115+
if len(customScalars) != expectedCount {
116+
t.Errorf("GetUsedCustomScalars() returned %d scalars, want %d. Got: %v", len(customScalars), expectedCount, customScalars)
117+
}
118+
119+
// Check that CustomScalar is in the list
120+
found := false
121+
for _, scalar := range customScalars {
122+
if scalar == "CustomScalar" {
123+
found = true
124+
break
125+
}
126+
}
127+
if !found {
128+
t.Errorf("GetUsedCustomScalars() did not include 'CustomScalar'")
129+
}
130+
131+
// Check that built-in scalars and known_scalars are not included
132+
for _, scalar := range customScalars {
133+
if scalar == "ID" || scalar == "Int" {
134+
t.Errorf("GetUsedCustomScalars() should not include built-in scalar %q", scalar)
135+
}
136+
if scalar == "DateTime" || scalar == "Upload" {
137+
t.Errorf("GetUsedCustomScalars() should not include known_scalar %q", scalar)
138+
}
139+
}
140+
}

0 commit comments

Comments
 (0)