diff --git a/example/feature_demo/demo_multi_file.pb.go b/example/feature_demo/demo_multi_file.pb.go index 6580325c..157a2b63 100644 --- a/example/feature_demo/demo_multi_file.pb.go +++ b/example/feature_demo/demo_multi_file.pb.go @@ -26,7 +26,7 @@ type ExternalChild struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // string id + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` } func (x *ExternalChild) Reset() { diff --git a/example/feature_demo/demo_multi_file.proto b/example/feature_demo/demo_multi_file.proto index 70a942e7..a9b68561 100644 --- a/example/feature_demo/demo_multi_file.proto +++ b/example/feature_demo/demo_multi_file.proto @@ -9,7 +9,7 @@ option go_package = "github.com/infobloxopen/protoc-gen-gorm/example/feature_dem option (gorm.opts) = { ormable: true, }; - string id = 1; // string id + string id = 1; } message BlogPost { diff --git a/example/feature_demo/demo_types.pb.go b/example/feature_demo/demo_types.pb.go index 439384e6..0811da32 100644 --- a/example/feature_demo/demo_types.pb.go +++ b/example/feature_demo/demo_types.pb.go @@ -254,7 +254,7 @@ func (x *TestTypes) GetSeveralValues() []*types.JSONValue { return nil } -// TypeWithID demonstrates some basic assocation behavior +// TypeWithID demonstrates some basic association behavior type TypeWithID struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/example/feature_demo/demo_types.proto b/example/feature_demo/demo_types.proto index ffe45f40..c6ce01a6 100644 --- a/example/feature_demo/demo_types.proto +++ b/example/feature_demo/demo_types.proto @@ -76,7 +76,7 @@ message TestTypes { repeated gorm.types.JSONValue several_values = 14; } -// TypeWithID demonstrates some basic assocation behavior +// TypeWithID demonstrates some basic association behavior message TypeWithID { // Again we use the 'ormable' option, but also include an extra field // using the 'include' option. Any number of fields can be defined this way @@ -119,7 +119,7 @@ message TypeWithID { // MultiaccountTypeWithID demonstrates the generated multi-account support message MultiaccountTypeWithID { - // here we use the multi_account option to generate auth integration in + // here we use the multi-account option to generate auth integration in // the ORM layer, and an assumed "account_id" column option (gorm.opts) = { ormable: true, diff --git a/example/postgres_arrays/postgres_arrays.pb.go b/example/postgres_arrays/postgres_arrays.pb.go index 89217152..5fac23ef 100644 --- a/example/postgres_arrays/postgres_arrays.pb.go +++ b/example/postgres_arrays/postgres_arrays.pb.go @@ -26,6 +26,7 @@ type Example struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // id for example Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` ArrayOfBools []bool `protobuf:"varint,20,rep,packed,name=array_of_bools,json=arrayOfBools,proto3" json:"array_of_bools,omitempty"` diff --git a/example/postgres_arrays/postgres_arrays.pb.gorm.go b/example/postgres_arrays/postgres_arrays.pb.gorm.go index a6f7e18e..0c4d94e9 100644 --- a/example/postgres_arrays/postgres_arrays.pb.gorm.go +++ b/example/postgres_arrays/postgres_arrays.pb.gorm.go @@ -4,7 +4,6 @@ import ( context "context" fmt "fmt" gateway "github.com/infobloxopen/atlas-app-toolkit/gateway" - gorm1 "github.com/infobloxopen/atlas-app-toolkit/gorm" errors "github.com/infobloxopen/protoc-gen-gorm/errors" pq "github.com/lib/pq" field_mask "google.golang.org/genproto/protobuf/field_mask" @@ -165,9 +164,6 @@ func DefaultReadExample(ctx context.Context, in *Example, db *gorm.DB) (*Example return nil, err } } - if db, err = gorm1.ApplyFieldSelection(ctx, db, nil, &ExampleORM{}); err != nil { - return nil, err - } if hook, ok := interface{}(&ormObj).(ExampleORMWithBeforeReadFind); ok { if db, err = hook.BeforeReadFind(ctx, db); err != nil { return nil, err @@ -443,10 +439,6 @@ func DefaultListExample(ctx context.Context, db *gorm.DB) ([]*Example, error) { return nil, err } } - db, err = gorm1.ApplyCollectionOperators(ctx, db, &ExampleORM{}, &Example{}, nil, nil, nil, nil) - if err != nil { - return nil, err - } if hook, ok := interface{}(&ormObj).(ExampleORMWithBeforeListFind); ok { if db, err = hook.BeforeListFind(ctx, db); err != nil { return nil, err diff --git a/example/postgres_arrays/postgres_arrays.proto b/example/postgres_arrays/postgres_arrays.proto index 52b4d7eb..ccc03868 100644 --- a/example/postgres_arrays/postgres_arrays.proto +++ b/example/postgres_arrays/postgres_arrays.proto @@ -9,6 +9,7 @@ option go_package = "github.com/infobloxopen/protoc-gen-gorm/example/postgres_ar message Example { option (gorm.opts) = {ormable: true}; + // id for example string id = 1 [(gorm.field).tag = {type: "uuid" primary_key: true}]; string description = 2; repeated bool array_of_bools = 20; diff --git a/example/user/user.pb.go b/example/user/user.pb.go index 81e6a42e..0b9d66ab 100644 --- a/example/user/user.pb.go +++ b/example/user/user.pb.go @@ -583,6 +583,7 @@ type Department struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // department name and id are composite primary keys Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` Id int64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` } diff --git a/example/user/user.pb.gorm.go b/example/user/user.pb.gorm.go index 8e1fa5a5..b0f54174 100644 --- a/example/user/user.pb.gorm.go +++ b/example/user/user.pb.gorm.go @@ -3036,6 +3036,9 @@ func DefaultReadTask(ctx context.Context, in *Task, db *gorm.DB) (*Task, error) if err != nil { return nil, err } + if ormObj.Id == nil || *ormObj.Id == "" { + return nil, errors.EmptyIdError + } if hook, ok := interface{}(&ormObj).(TaskORMWithBeforeReadApplyQuery); ok { if db, err = hook.BeforeReadApplyQuery(ctx, db); err != nil { return nil, err @@ -3077,6 +3080,9 @@ func DefaultDeleteTask(ctx context.Context, in *Task, db *gorm.DB) error { if err != nil { return err } + if ormObj.Id == nil || *ormObj.Id == "" { + return errors.EmptyIdError + } if hook, ok := interface{}(&ormObj).(TaskORMWithBeforeDelete_); ok { if db, err = hook.BeforeDelete_(ctx, db); err != nil { return err @@ -3439,10 +3445,10 @@ func DefaultDeleteDepartment(ctx context.Context, in *Department, db *gorm.DB) e if err != nil { return err } - if ormObj.Name == "" { + if ormObj.Id == 0 { return errors.EmptyIdError } - if ormObj.Id == 0 { + if ormObj.Name == "" { return errors.EmptyIdError } if hook, ok := interface{}(&ormObj).(DepartmentORMWithBeforeDelete_); ok { @@ -3509,7 +3515,7 @@ func DefaultListDepartment(ctx context.Context, db *gorm.DB) ([]*Department, err } } db = db.Where(&ormObj) - db = db.Order("name, id") + db = db.Order("id, name") ormResponse := []DepartmentORM{} if err := db.Find(&ormResponse).Error; err != nil { return nil, err diff --git a/example/user/user.proto b/example/user/user.proto index e5fae109..05f516de 100644 --- a/example/user/user.proto +++ b/example/user/user.proto @@ -98,6 +98,7 @@ message Department { option (gorm.opts) = { ormable: true, }; + // department name and id are composite primary keys string name = 1 [(gorm.field).tag = {primary_key: true}]; int64 id = 2 [(gorm.field).tag = {primary_key: true}]; } diff --git a/plugin/plugin.go b/plugin/plugin.go index 38d899d3..82bd0c33 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -98,6 +98,11 @@ var optionalTypes = map[string]string{ "bool": "*bool", } +type pkFieldObjs struct { + name string + field *Field +} + const ( protoTypeTimestamp = "Timestamp" // last segment, first will be *google_protobufX protoTypeDuration = "Duration" @@ -579,20 +584,35 @@ func (b *ORMBuilder) findPrimaryKey(ormable *OrmableType) (string, *Field) { panic("no primary_key") } -func (b *ORMBuilder) getPrimaryKeys(ormable *OrmableType) map[string]*Field { - mapPK := make(map[string]*Field) +// getPrimaryKeys returns a sorted list of primary key field objects +func (b *ORMBuilder) getPrimaryKeys(ormable *OrmableType) []pkFieldObjs { + var ( + fieldobjs []pkFieldObjs + ) + for fieldName, field := range ormable.Fields { if field.GetTag().GetPrimaryKey() { - mapPK[fieldName] = field + fieldobjs = append(fieldobjs, pkFieldObjs{ + fieldName, + field, + }) } } - // consider field name "id" as well - for fieldName, field := range ormable.Fields { - if strings.ToLower(fieldName) == "id" { - mapPK[fieldName] = field + sort.Slice(fieldobjs, func(i, j int) bool { + return fieldobjs[i].name < fieldobjs[j].name + }) + // if no primary key is found, use the field named "id" + if len(fieldobjs) == 0 { + for fieldName, field := range ormable.Fields { + if strings.ToLower(fieldName) == "id" { + fieldobjs = append(fieldobjs, pkFieldObjs{ + fieldName, + field, + }) + } } } - return mapPK + return fieldobjs } func (b *ORMBuilder) getOrmable(typeName string) *OrmableType { @@ -1851,12 +1871,12 @@ func (b *ORMBuilder) generateReadHandler(message *protogen.Message, g *protogen. g.P(`return nil, err`) g.P(`}`) - mapPK := b.getPrimaryKeys(ormable) - for k, f := range mapPK { - if strings.Contains(f.TypeName, "*") { - g.P(`if ormObj.`, k, ` == nil || *ormObj.`, k, ` == `, b.guessZeroValue(f.TypeName, g), ` {`) + pkfieldMapping := b.getPrimaryKeys(ormable) + for _, pkfieldObj := range pkfieldMapping { + if strings.Contains(pkfieldObj.field.TypeName, "*") { + g.P(`if ormObj.`, pkfieldObj.name, ` == nil || *ormObj.`, pkfieldObj.name, ` == `, b.guessZeroValue(pkfieldObj.field.TypeName, g), ` {`) } else { - g.P(`if ormObj.`, k, ` == `, b.guessZeroValue(f.TypeName, g), ` {`) + g.P(`if ormObj.`, pkfieldObj.name, ` == `, b.guessZeroValue(pkfieldObj.field.TypeName, g), ` {`) } g.P(`return nil, `, "errors", `.EmptyIdError`) g.P(`}`) @@ -1992,12 +2012,12 @@ func (b *ORMBuilder) generateDeleteHandler(message *protogen.Message, g *protoge g.P(`}`) ormable := b.getOrmable(typeName) - mapPKs := b.getPrimaryKeys(ormable) - for pkName, pk := range mapPKs { - if strings.Contains(pk.TypeName, "*") { - g.P(`if ormObj.`, pkName, ` == nil || *ormObj.`, pkName, ` == `, b.guessZeroValue(pk.TypeName, g), ` {`) + pkFieldMapping := b.getPrimaryKeys(ormable) + for _, pkFieldObj := range pkFieldMapping { + if strings.Contains(pkFieldObj.field.TypeName, "*") { + g.P(`if ormObj.`, pkFieldObj.name, ` == nil || *ormObj.`, pkFieldObj.name, ` == `, b.guessZeroValue(pkFieldObj.field.TypeName, g), ` {`) } else { - g.P(`if ormObj.`, pkName, ` == `, b.guessZeroValue(pk.TypeName, g), `{`) + g.P(`if ormObj.`, pkFieldObj.name, ` == `, b.guessZeroValue(pkFieldObj.field.TypeName, g), `{`) } g.P(`return `, generateImport("EmptyIdError", gerrorsImport, g)) g.P(`}`) @@ -2672,12 +2692,12 @@ func (b *ORMBuilder) generateListHandler(message *protogen.Message, g *protogen. // TODO handle composite primary keys order considering priority tag if b.hasCompositePrimaryKey(ormable) { - pksMap := b.getPrimaryKeys(ormable) + pkFieldMapping := b.getPrimaryKeys(ormable) var columns []string - for fieldName, field := range pksMap { - column := field.GetTag().GetColumn() + for _, pkFieldObj := range pkFieldMapping { + column := pkFieldObj.field.GetTag().GetColumn() if len(column) == 0 { - column = gschema.NamingStrategy{SingularTable: true}.TableName(fieldName) + column = gschema.NamingStrategy{SingularTable: true}.TableName(pkFieldObj.name) } columns = append(columns, column) }