Skip to content

Commit 87a0512

Browse files
committed
fix(tool): consolidate Structs and StructTypes in generator
1 parent 8c6de1e commit 87a0512

File tree

3 files changed

+122
-41
lines changed

3 files changed

+122
-41
lines changed

generator/autogen.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ func (g *Generator) BuildDependencyGraph() *DependencyGraph {
109109
}
110110

111111
// Second pass: Build edges by analyzing struct fields
112-
for typeName, structType := range g.P.Structs {
113-
if structType.Fields == nil {
112+
for typeName, typeSpec := range g.P.StructTypes {
113+
structType, ok := typeSpec.Type.(*ast.StructType)
114+
if !ok || structType.Fields == nil {
114115
continue
115116
}
116117
for _, field := range structType.Fields.List {
@@ -151,13 +152,17 @@ func (g *Generator) extractTypeReferences(expr ast.Expr) []string {
151152
if ident, ok := t.X.(*ast.Ident); ok {
152153
// Try to find the type in our structs
153154
typeName := t.Sel.Name
154-
if _, exists := g.P.Structs[typeName]; exists {
155-
types = append(types, typeName)
155+
if ts, exists := g.P.StructTypes[typeName]; exists {
156+
if _, ok := ts.Type.(*ast.StructType); ok {
157+
types = append(types, typeName)
158+
}
156159
}
157160
// Also check with package prefix
158161
fullName := ident.Name + "." + typeName
159-
if _, exists := g.P.Structs[fullName]; exists {
160-
types = append(types, fullName)
162+
if ts, exists := g.P.StructTypes[fullName]; exists {
163+
if _, ok := ts.Type.(*ast.StructType); ok {
164+
types = append(types, fullName)
165+
}
161166
}
162167
}
163168

generator/generator.go

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ func (g *Generator) Run() error {
156156
g.applyAutoGeneration(depGraph)
157157
}
158158

159+
// Build scanned types registry with GQL annotation metadata
160+
g.buildScannedTypesRegistry()
161+
159162
// Check if we have any namespaces defined
160163
hasNamespaces := len(g.P.TypeNamespaces) > 0 || len(g.P.EnumNamespaces) > 0
161164

@@ -295,13 +298,17 @@ func (g *Generator) buildDependencyOrder() []string {
295298
return
296299
}
297300
visited[n] = true
298-
st := g.P.Structs[n]
299-
if st == nil {
301+
typeSpec := g.P.StructTypes[n]
302+
st, ok := typeSpec.Type.(*ast.StructType)
303+
if !ok {
300304
return
301305
}
302306
for _, f := range st.Fields.List {
303307
ft := FieldTypeName(f.Type)
304-
if _, ok := g.P.Structs[ft]; ok {
308+
if ts, ok := g.P.StructTypes[ft]; ok {
309+
if _, ok := ts.Type.(*ast.StructType); !ok {
310+
continue
311+
}
305312
dfs(ft)
306313
}
307314
}
@@ -322,7 +329,7 @@ func (g *Generator) generateTypeContent(typeName string, typeSpec *ast.TypeSpec,
322329
if d.HasTypeDirective {
323330
slog.Debug("Generating type from directive", "type", typeName, "count", len(d.Types))
324331
for _, typeDef := range d.Types {
325-
typeContent := g.generateTypeFromDef(typeSpec, g.P.Structs[typeName], d, typeDef, ctx)
332+
typeContent := g.generateTypeFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, typeDef, ctx)
326333
if typeContent != "" {
327334
buf.WriteString(typeContent)
328335
}
@@ -334,7 +341,7 @@ func (g *Generator) generateTypeContent(typeName string, typeSpec *ast.TypeSpec,
334341
Name: typeName,
335342
Description: "",
336343
}
337-
typeContent := g.generateTypeFromDef(typeSpec, g.P.Structs[typeName], d, defaultTypeDef, ctx)
344+
typeContent := g.generateTypeFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, defaultTypeDef, ctx)
338345
if typeContent != "" {
339346
buf.WriteString(typeContent)
340347
}
@@ -352,7 +359,7 @@ func (g *Generator) generateInputContent(typeName string, typeSpec *ast.TypeSpec
352359
if d.HasInputDirective {
353360
slog.Debug("Generating input from directive", "type", typeName, "count", len(d.Inputs))
354361
for _, inputDef := range d.Inputs {
355-
inputContent := g.generateInputFromDef(typeSpec, g.P.Structs[typeName], d, inputDef, ctx)
362+
inputContent := g.generateInputFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, inputDef, ctx)
356363
if inputContent != "" {
357364
buf.WriteString(inputContent)
358365
}
@@ -364,7 +371,7 @@ func (g *Generator) generateInputContent(typeName string, typeSpec *ast.TypeSpec
364371
Name: typeName + "Input",
365372
Description: "",
366373
}
367-
inputContent := g.generateInputFromDef(typeSpec, g.P.Structs[typeName], d, defaultInputDef, ctx)
374+
inputContent := g.generateInputFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, defaultInputDef, ctx)
368375
if inputContent != "" {
369376
buf.WriteString(inputContent)
370377
}
@@ -534,7 +541,7 @@ func (g *Generator) generateByNamespace(orders []string) (map[string]string, err
534541
for _, typeName := range typeNames {
535542
typeDefs := items.types[typeName]
536543
typeSpec := g.P.StructTypes[typeName]
537-
structType := g.P.Structs[typeName]
544+
structType := typeSpec.Type.(*ast.StructType)
538545
d := ParseDirectives(typeSpec, g.P.TypeToDecl[typeName])
539546

540547
for _, typeDef := range typeDefs {
@@ -543,13 +550,11 @@ func (g *Generator) generateByNamespace(orders []string) (map[string]string, err
543550
buf.WriteString(typeContent)
544551
}
545552
}
546-
}
547-
548-
// Generate inputs for this namespace
553+
} // Generate inputs for this namespace
549554
slog.Debug("Generating inputs for namespace", "namespace", namespace, "count", len(items.inputs))
550555
for typeName, inputDefs := range items.inputs {
551556
typeSpec := g.P.StructTypes[typeName]
552-
structType := g.P.Structs[typeName]
557+
structType := typeSpec.Type.(*ast.StructType)
553558
d := ParseDirectives(typeSpec, g.P.TypeToDecl[typeName])
554559

555560
for _, inputDef := range inputDefs {
@@ -718,7 +723,7 @@ func (g *Generator) generateByNamespaceAndPackage(orders []string) (map[string]s
718723
Strategy: "namespace+package",
719724
Namespace: ns,
720725
}
721-
typeContent := g.generateTypeFromDef(typeSpec, g.P.Structs[typeName], d, typeDef, ctx)
726+
typeContent := g.generateTypeFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, typeDef, ctx)
722727
if typeContent != "" {
723728
buf.WriteString(typeContent)
724729
}
@@ -760,7 +765,7 @@ func (g *Generator) generateByNamespaceAndPackage(orders []string) (map[string]s
760765
Strategy: "namespace+package",
761766
Namespace: ns,
762767
}
763-
inputContent := g.generateInputFromDef(typeSpec, g.P.Structs[typeName], d, inputDef, ctx)
768+
inputContent := g.generateInputFromDef(typeSpec, typeSpec.Type.(*ast.StructType), d, inputDef, ctx)
764769
if inputContent != "" {
765770
buf.WriteString(inputContent)
766771
}
@@ -1006,7 +1011,7 @@ func (g *Generator) generatePackageFiles(orders []string) (map[string]string, er
10061011
for _, typeName := range typeNames {
10071012
typeDefs := items.types[typeName]
10081013
typeSpec := g.P.StructTypes[typeName]
1009-
structType := g.P.Structs[typeName]
1014+
structType := typeSpec.Type.(*ast.StructType)
10101015
d := ParseDirectives(typeSpec, g.P.TypeToDecl[typeName])
10111016

10121017
for _, typeDef := range typeDefs {
@@ -1015,13 +1020,11 @@ func (g *Generator) generatePackageFiles(orders []string) (map[string]string, er
10151020
buf.WriteString(typeContent)
10161021
}
10171022
}
1018-
}
1019-
1020-
// Generate inputs for this package
1023+
} // Generate inputs for this package
10211024
slog.Debug("Generating inputs for package", "package", pkgPath, "count", len(items.inputs))
10221025
for typeName, inputDefs := range items.inputs {
10231026
typeSpec := g.P.StructTypes[typeName]
1024-
structType := g.P.Structs[typeName]
1027+
structType := typeSpec.Type.(*ast.StructType)
10251028
d := ParseDirectives(typeSpec, g.P.TypeToDecl[typeName])
10261029

10271030
for _, inputDef := range inputDefs {
@@ -1673,15 +1676,18 @@ func (g *Generator) expandEmbeddedFieldNamed(f *ast.Field, d StructDirectives, i
16731676
}
16741677

16751678
// Look up the embedded struct in the parser
1676-
embeddedStruct, exists := g.P.Structs[embeddedTypeName]
1679+
typeSpec, exists := g.P.StructTypes[embeddedTypeName]
16771680
if !exists {
16781681
return "" // Embedded struct not found in parsed types
16791682
}
1683+
embeddedStruct, ok := typeSpec.Type.(*ast.StructType)
1684+
if !ok {
1685+
return "" // Not a struct type
1686+
}
16801687

16811688
// If generating for input and the embedded type should be auto-generated as input,
16821689
// mark it for generation
16831690
if forInput {
1684-
typeSpec := g.P.StructTypes[embeddedTypeName]
16851691
if typeSpec != nil {
16861692
genDecl := g.P.TypeToDecl[embeddedTypeName]
16871693
directives := ParseDirectives(typeSpec, genDecl)
@@ -1994,6 +2000,43 @@ func (g *Generator) extractBaseTypeName(graphQLType string) string {
19942000
}
19952001

19962002
// isTypeInScope checks if a type name exists in the parsed types (structs or enums)
2003+
// buildScannedTypesRegistry populates the ScannedTypes registry with metadata about all scanned types
2004+
func (g *Generator) buildScannedTypesRegistry() {
2005+
for typeName, typeSpec := range g.P.StructTypes {
2006+
if genDecl, hasDecl := g.P.TypeToDecl[typeName]; hasDecl {
2007+
directives := ParseDirectives(typeSpec, genDecl)
2008+
2009+
info := &ScannedTypeInfo{
2010+
TypeName: typeName,
2011+
HasTypeDirective: directives.HasTypeDirective,
2012+
HasInputDirective: directives.HasInputDirective,
2013+
GeneratedTypes: make([]string, 0),
2014+
GeneratedInputs: make([]string, 0),
2015+
}
2016+
2017+
// Collect generated type names from @gqlType annotations
2018+
for _, typeDef := range directives.Types {
2019+
if typeDef.Name != "" {
2020+
info.GeneratedTypes = append(info.GeneratedTypes, typeDef.Name)
2021+
} else {
2022+
info.GeneratedTypes = append(info.GeneratedTypes, typeName)
2023+
}
2024+
}
2025+
2026+
// Collect generated input names from @gqlInput annotations
2027+
for _, inputDef := range directives.Inputs {
2028+
if inputDef.Name != "" {
2029+
info.GeneratedInputs = append(info.GeneratedInputs, inputDef.Name)
2030+
} else {
2031+
info.GeneratedInputs = append(info.GeneratedInputs, typeName+"Input")
2032+
}
2033+
}
2034+
2035+
g.P.ScannedTypes[typeName] = info
2036+
}
2037+
}
2038+
}
2039+
19972040
func (g *Generator) isTypeInScope(typeName string) bool {
19982041
if typeName == "" {
19992042
return true
@@ -2004,7 +2047,23 @@ func (g *Generator) isTypeInScope(typeName string) bool {
20042047
return true
20052048
}
20062049

2007-
// Check if it's a struct type
2050+
// Check if this type name was generated from a scanned type with GQL annotations
2051+
for _, scannedInfo := range g.P.ScannedTypes {
2052+
// Check if typeName matches any generated type names
2053+
for _, genType := range scannedInfo.GeneratedTypes {
2054+
if genType == typeName {
2055+
return true
2056+
}
2057+
}
2058+
// Check if typeName matches any generated input names
2059+
for _, genInput := range scannedInfo.GeneratedInputs {
2060+
if genInput == typeName {
2061+
return true
2062+
}
2063+
}
2064+
}
2065+
2066+
// Check if it's a struct type (scanned but no annotations)
20082067
if _, exists := g.P.StructTypes[typeName]; exists {
20092068
return true
20102069
}

generator/parser.go

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ type EnumType struct {
3636
// Parser collects type specs and related AST nodes across a root dir
3737
type Parser struct {
3838
StructTypes map[string]*ast.TypeSpec
39-
Structs map[string]*ast.StructType
4039
PackageNames map[string]string
4140
PackagePaths map[string]string // Full import path for each type
4241
SourceFiles map[string]string // Source file path for each type (absolute OS path)
@@ -57,12 +56,23 @@ type Parser struct {
5756
EnumSourceFiles map[string]string // enum name -> source file path
5857
// Type parameters for generic types
5958
TypeParameters map[string][]string // type name -> parameter names (e.g., "Result" -> ["T"], "Map" -> ["K", "V"])
59+
// Scanned types registry - tracks all types we've scanned with their GQL annotations
60+
// Key: Go type name, Value: metadata about the scanned type
61+
ScannedTypes map[string]*ScannedTypeInfo
62+
}
63+
64+
// ScannedTypeInfo stores metadata about a scanned type
65+
type ScannedTypeInfo struct {
66+
TypeName string // Go type name
67+
HasTypeDirective bool // Has @gqlType annotation
68+
HasInputDirective bool // Has @gqlInput annotation
69+
GeneratedTypes []string // List of GraphQL type names generated from this struct (from @gqlType)
70+
GeneratedInputs []string // List of GraphQL input names generated from this struct (from @gqlInput)
6071
}
6172

6273
func NewParser() *Parser {
6374
return &Parser{
6475
StructTypes: make(map[string]*ast.TypeSpec),
65-
Structs: make(map[string]*ast.StructType),
6676
PackageNames: make(map[string]string),
6777
PackagePaths: make(map[string]string),
6878
SourceFiles: make(map[string]string),
@@ -74,6 +84,7 @@ func NewParser() *Parser {
7484
EnumNamespaces: make(map[string]string),
7585
EnumSourceFiles: make(map[string]string),
7686
TypeParameters: make(map[string][]string),
87+
ScannedTypes: make(map[string]*ScannedTypeInfo),
7788
}
7889
}
7990

@@ -150,10 +161,9 @@ func (p *Parser) parseFile(path string) error {
150161
}
151162

152163
// Check if it's a struct
153-
if s, ok := t.Type.(*ast.StructType); ok {
164+
if _, ok := t.Type.(*ast.StructType); ok {
154165
name := t.Name.Name
155166
p.StructTypes[name] = t
156-
p.Structs[name] = s
157167
p.PackageNames[name] = pkgName
158168
p.PackagePaths[name] = path
159169
p.SourceFiles[name] = path // Store source file path
@@ -342,11 +352,10 @@ func (p *Parser) GetPackageImportPath(typeName string, modelPath string) string
342352
if pkgIndex == -1 {
343353
// Package directory not found in path, use modelPath as-is
344354
return modelPath
345-
}
346-
347-
// Check if there are meaningful parent directories between module root and package
355+
} // Check if there are meaningful parent directories between module root and package
348356
// Look for structure like: internal/models, pkg/entities, api/v2/models, etc.
349357
var subPath []string
358+
stopAtNext := false
350359

351360
// Collect directories from package backward until we hit a likely module boundary
352361
for i := pkgIndex; i >= 0; i-- {
@@ -357,18 +366,26 @@ func (p *Parser) GetPackageImportPath(typeName string, modelPath string) string
357366
continue
358367
}
359368

369+
// If we should stop after this iteration
370+
if stopAtNext {
371+
break
372+
}
373+
360374
subPath = append([]string{part}, subPath...)
361375

362-
// Stop if we hit common module structure markers (but include them)
376+
// Stop if we hit common module structure markers (but include them in the path)
363377
if part == "internal" || part == "pkg" || part == "cmd" || part == "api" {
364378
break
365379
}
366-
}
367380

368-
// If we only have the package name itself, return modelPath as-is
369-
// This handles cases where modelPath already points to the complete package location
370-
if len(subPath) == 1 && subPath[0] == pkgName {
371-
return modelPath
381+
// Also check if parent is a common test/development directory name - these typically are at module root
382+
if i > 0 {
383+
parentPart := parts[i-1]
384+
// If parent is "dev", "test", "examples", etc., stop after including current part
385+
if parentPart == "dev" || parentPart == "test" || parentPart == "examples" || parentPart == "demo" {
386+
stopAtNext = true
387+
}
388+
}
372389
}
373390

374391
// Otherwise, append the sub-path to modelPath

0 commit comments

Comments
 (0)