Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing and required fields #28

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 97 additions & 24 deletions components/definition/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type Definition struct {
Type string `json:"type"`
Properties map[string]DefinitionProperties `json:"properties"`
Required []string `json:"required"`
}

// DefinitionProperties defines the details of a property within a Definition,
Expand All @@ -25,6 +26,9 @@ type DefinitionProperties struct {
Ref string `json:"$ref,omitempty"`
Items *DefinitionPropertiesItems `json:"items,omitempty"`
Example interface{} `json:"example,omitempty"`

// keep this info to fill Required fields later
IsRequired bool `json:"-"`
}

// DefinitionPropertiesItems specifies the type or reference of array items when
Expand Down Expand Up @@ -70,10 +74,38 @@ func (g DefinitionGenerator) CreateDefinition(t interface{}) {
properties = g.createStructDefinitions(reflectReturn)
}

// merge embedded struct fields with other fields
g.mergeEmbeddedStructFields(properties)

g.Definitions[definitionName] = Definition{
Type: "object",
Properties: properties,
Required: g.findRequiredFields(properties),
}
}

func (g DefinitionGenerator) mergeEmbeddedStructFields(properties map[string]DefinitionProperties) {
for k, v := range properties {
if k == "" && v.Ref != "" { // identify embedded structs
embeddedModelName, _ := strings.CutPrefix(v.Ref, "#/definitions/")
if def, ok := g.Definitions[embeddedModelName]; ok {
for propName, propValue := range def.Properties {
properties[propName] = propValue
}
delete(properties, "")
}
}
}
}

func (g DefinitionGenerator) findRequiredFields(properties map[string]DefinitionProperties) []string {
requiredFields := []string{}
for k, v := range properties {
if v.IsRequired {
requiredFields = append(requiredFields, k)
}
}
return requiredFields
}

func (g DefinitionGenerator) createStructDefinitions(structType reflect.Type) map[string]DefinitionProperties {
Expand All @@ -96,37 +128,66 @@ func (g DefinitionGenerator) createStructDefinitions(structType reflect.Type) ma
// if item type is array, create Definition for array element type
switch fieldType {
case "array":
if field.Type.Elem().Kind() == reflect.Struct {
if field.Type.Elem().Kind() == reflect.Pointer { // []*type
if field.Type.Elem().Elem().Kind() == reflect.Struct { // []*struct
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Type: fieldType,
Items: &DefinitionPropertiesItems{
Ref: fmt.Sprintf("#/definitions/%s", field.Type.Elem().Elem().String()),
},
IsRequired: g.isRequired(field),
}
if structType == field.Type.Elem() {
continue // prevent recursion
}
g.CreateDefinition(reflect.New(field.Type.Elem().Elem()).Elem().Interface())
} else { // []*other
itemType := fields.Type(field.Type.Elem().Elem().Kind().String())
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Type: fieldType,
Items: &DefinitionPropertiesItems{
Type: itemType,
},
IsRequired: g.isRequired(field),
}
}
} else if field.Type.Elem().Kind() == reflect.Struct { // []struct
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Type: fieldType,
Items: &DefinitionPropertiesItems{
Ref: fmt.Sprintf("#/definitions/%s", field.Type.Elem().String()),
},
IsRequired: g.isRequired(field),
}
if structType == field.Type.Elem() {
continue // prevent recursion
}
g.CreateDefinition(reflect.New(field.Type.Elem()).Elem().Interface())
} else {
} else { // []other
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Type: fieldType,
Items: &DefinitionPropertiesItems{
Type: fields.Type(field.Type.Elem().Kind().String()),
},
IsRequired: g.isRequired(field),
}
}

case "struct":
isRequiredField := g.isRequired(field)
if field.Type.String() == "time.Time" {
properties[fieldJsonTag] = g.timeProperty(field)
properties[fieldJsonTag] = g.timeProperty(field, isRequiredField)
} else if field.Type.String() == "time.Duration" {
properties[fieldJsonTag] = g.durationProperty(field)
properties[fieldJsonTag] = g.durationProperty(field, isRequiredField)
} else {
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Ref: fmt.Sprintf("#/definitions/%s", field.Type.String()),
Example: fields.ExampleTag(field),
Ref: fmt.Sprintf("#/definitions/%s", field.Type.String()),
IsRequired: isRequiredField,
}
g.CreateDefinition(reflect.New(field.Type).Elem().Interface())
}
Expand All @@ -140,11 +201,11 @@ func (g DefinitionGenerator) createStructDefinitions(structType reflect.Type) ma
}
if field.Type.Elem().Kind() == reflect.Struct {
if field.Type.Elem().String() == "time.Time" {
properties[fieldJsonTag] = g.timeProperty(field)
properties[fieldJsonTag] = g.timeProperty(field, false)
} else if field.Type.String() == "time.Duration" {
properties[fieldJsonTag] = g.durationProperty(field)
properties[fieldJsonTag] = g.durationProperty(field, false)
} else {
properties[fieldJsonTag] = g.refProperty(field)
properties[fieldJsonTag] = g.refProperty(field, false)
g.CreateDefinition(reflect.New(field.Type.Elem()).Elem().Interface())
}
} else if field.Type.Elem().Kind() == reflect.Array || field.Type.Elem().Kind() == reflect.Slice {
Expand Down Expand Up @@ -210,43 +271,55 @@ func (g DefinitionGenerator) createStructDefinitions(structType reflect.Type) ma
case "interface":
// TODO: Find a way to get real model of interface{}
properties[fieldJsonTag] = DefinitionProperties{
Example: fields.ExampleTag(field),
Type: "Ambiguous Type: interface{}",
Example: fields.ExampleTag(field),
Type: "Ambiguous Type: interface{}",
IsRequired: g.isRequired(field),
}
default:

default:
properties[fieldJsonTag] = g.defaultProperty(field)

}
}

return properties
}

func (g DefinitionGenerator) timeProperty(field reflect.StructField) DefinitionProperties {
func (g DefinitionGenerator) timeProperty(field reflect.StructField, required bool) DefinitionProperties {
return DefinitionProperties{
Example: fields.ExampleTag(field),
Type: "string",
Format: "date-time",
Example: fields.ExampleTag(field),
Type: "string",
Format: "date-time",
IsRequired: required,
}
}

func (g DefinitionGenerator) durationProperty(field reflect.StructField) DefinitionProperties {
func (g DefinitionGenerator) durationProperty(field reflect.StructField, required bool) DefinitionProperties {
return DefinitionProperties{
Example: fields.ExampleTag(field),
Type: "integer",
Example: fields.ExampleTag(field),
Type: "integer",
IsRequired: required,
}
}

func (g DefinitionGenerator) refProperty(field reflect.StructField) DefinitionProperties {
func (g DefinitionGenerator) refProperty(field reflect.StructField, required bool) DefinitionProperties {
return DefinitionProperties{
Example: fields.ExampleTag(field),
Ref: fmt.Sprintf("#/definitions/%s", field.Type.Elem().String()),
Example: fields.ExampleTag(field),
Ref: fmt.Sprintf("#/definitions/%s", field.Type.Elem().String()),
IsRequired: required,
}
}

func (g DefinitionGenerator) defaultProperty(field reflect.StructField) DefinitionProperties {
return DefinitionProperties{
Example: fields.ExampleTag(field),
Type: fields.Type(field.Type.Kind().String()),
Example: fields.ExampleTag(field),
Type: fields.Type(field.Type.Kind().String()),
IsRequired: g.isRequired(field),
}
}

func (g DefinitionGenerator) isRequired(field reflect.StructField) bool {
hasRequiredTag := fields.IsRequired(field)
hasOmitemptyTag := fields.IsOmitempty(field)
return hasRequiredTag || !hasOmitemptyTag
}
17 changes: 17 additions & 0 deletions components/fields/parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ func JsonTag(field reflect.StructField) string {
return jsonTag
}

// IsOmitempty extracts the 'json' struct tag's value of a struct field and returns if it has omitempty.
func IsOmitempty(field reflect.StructField) bool {
jsonTag := field.Tag.Get("json")
for _, part := range strings.Split(jsonTag, ",") {
if strings.TrimSpace(part) == "omitempty" {
return true
}
}
return false
}

// IsRequired extracts the 'required' struct tag's value of a struct field and returns true if required is true.
func IsRequired(field reflect.StructField) bool {
tagValue := field.Tag.Get("required")
return tagValue == "true"
}

// Type maps a string to its corresponding Swagger type according to the
// Swagger Specification version 2 data types (https://swagger.io/specification/v2/#data-types).
func Type(t string) string {
Expand Down
27 changes: 25 additions & 2 deletions components/tag/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,37 @@ package tag

// https://swagger.io/specification/v2/#tagObject
type Tag struct {
Name string `json:"name"`
Description string `json:"description"`
ExternalDocs *ExternalDocs `json:"externalDocs,omitempty"`
}

type ExternalDocs struct {
Name string `json:"name"`
Description string `json:"description"`
}

type TagOpts func(*Tag)

func WithExternalDocs(name string, description string) TagOpts {
return func(t *Tag) {
t.ExternalDocs = &ExternalDocs{
Name: name,
Description: description,
}
}
}

// New returns a new Tag.
func New(name string, description string) Tag {
return Tag{
func New(name string, description string, opts ...TagOpts) Tag {
t := Tag{
Name: name,
Description: description,
}

for _, opt := range opts {
opt(&t)
}

return t
}
10 changes: 10 additions & 0 deletions example/fiber/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ func main() {
endpoint.WithConsume([]mime.MIME{mime.JSON}),
endpoint.WithSummary("this is a test summary"),
),
endpoint.New(
endpoint.POST,
"/product",
endpoint.WithTags("product"),
endpoint.WithBody(models.ProductPost{}),
endpoint.WithSuccessfulReturns([]response.Response{response.New(models.Product{}, "200", "OK")}),
endpoint.WithDescription(desc),
endpoint.WithProduce([]mime.MIME{mime.JSON, mime.XML}),
endpoint.WithConsume([]mime.MIME{mime.JSON}),
),
endpoint.New(
endpoint.GET,
"/product/{id}",
Expand Down
30 changes: 21 additions & 9 deletions example/models/products.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,27 @@ package models
import "time"

type Product struct {
Id uint64 `json:"id"`
Name string `json:"name"`
MerchantId uint64 `json:"merchant_id"`
CategoryId *uint64 `json:"category_id,omitempty"`
Tags []uint64 `json:"tags"`
Images []string `json:"image_ids"`
Sizes []Sizes `json:"sizes"`
SaleDate time.Time `json:"sale_date"`
EndDate *time.Time `json:"end_date"`
Id uint64 `json:"id"`
Name string `json:"name"`
MerchantId uint64 `json:"merchant_id"`
CategoryId *uint64 `json:"category_id,omitempty"`
Tags []uint64 `json:"tags"`
Images []*string `json:"image_ids"`
ImagesPtr *[]string `json:"image_ids_ptr"`
Sizes []Sizes `json:"sizes"`
SizePtrs []*Sizes `json:"size_ptrs"`
SaleDate time.Time `json:"sale_date"`
EndDate *time.Time `json:"end_date"`
Complex ComplexSuccessfulResponse `json:"complex"`
Interface interface{} `json:"interface"`
OmitEmpty string `json:"omitemptytest,omitempty"`
RequiredField interface{} `json:"required_field,omitempty" required:"true"`
EmbeddedStruct EmbeddedStruct `json:"embedded_struct"`
}

type EmbeddedStruct struct {
Sizes
OtherField int `json:"other_field"`
}

type Sizes struct {
Expand Down
2 changes: 1 addition & 1 deletion generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestSwaggerGeneration(t *testing.T) {
got.AddEndpoints(tc.endpoints)
got.generateSwaggerJson()

if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(Swagger{}, "endpoints"), cmpopts.IgnoreFields(definition.DefinitionProperties{}, "Example")); diff != "" {
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(Swagger{}, "endpoints"), cmpopts.IgnoreFields(definition.DefinitionProperties{}, "Example", "IsRequired")); diff != "" {
t.Errorf("JsonSwagger() mismatch (-expected +got):\n%s", diff)
}
})
Expand Down
25 changes: 15 additions & 10 deletions swagger.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Swagger struct {
// https://swagger.io/specification/v2/#info-object
type Info struct {
Title string `json:"title"`
Description string `json:"description"`
Version string `json:"version"`
TermsOfService string `json:"termsOfService,omitempty"`
Contact *Contact `json:"contact,omitempty"`
Expand Down Expand Up @@ -84,12 +85,14 @@ type License struct {

// Config struct represents the configuration for Swagger documentation.
type Config struct {
Title string // title of the Swagger documentation
Version string // version of the Swagger documentation
Host string // host URL for the API
Path string // path to the Swagger JSON file
License *License // license information for the Swagger documentation
Contact *Contact // contact information for the Swagger documentation
Title string // title of the Swagger documentation
Version string // version of the Swagger documentation
Description string // description of the Swagger documentation
Host string // host URL for the API
Path string // path to the Swagger JSON file
License *License // license information for the Swagger documentation
Contact *Contact // contact information for the Swagger documentation
TermsOfService string // term of service information for the Swagger documentation
}

// buildSwagger creates a new swagger instance with the given title, version, and optional arguments.
Expand All @@ -110,10 +113,12 @@ func buildSwagger(c Config) (swagger *Swagger) {
swagger = &Swagger{
Swagger: "2.0",
Info: Info{
Title: c.Title,
Version: c.Version,
License: c.License,
Contact: c.Contact,
Title: c.Title,
Description: c.Description,
Version: c.Version,
License: c.License,
Contact: c.Contact,
TermsOfService: c.TermsOfService,
},
Paths: make(map[string]map[string]endpoint.JsonEndPoint),
BasePath: c.Path,
Expand Down
Loading