diff --git a/go/arrow/array/table.go b/go/arrow/array/table.go
index 88362c74f26eb..af0b1c250c0b3 100644
--- a/go/arrow/array/table.go
+++ b/go/arrow/array/table.go
@@ -132,9 +132,9 @@ func NewTable(schema *arrow.Schema, cols []arrow.Column, rows int64) *simpleTabl
 // of slices of arrow.Array.
 //
 // Like other NewTable functions this can panic if:
-//  - len(schema.Fields) != len(data)
-//  - the total length of each column's array slice (ie: number of rows
-//    in the column) aren't the same for all columns.
+//   - len(schema.Fields) != len(data)
+//   - the total length of each column's array slice (ie: number of rows
+//     in the column) aren't the same for all columns.
 func NewTableFromSlice(schema *arrow.Schema, data [][]arrow.Array) *simpleTable {
 	if len(data) != len(schema.Fields()) {
 		panic("array/table: mismatch in number of columns and data for creating a table")
@@ -197,7 +197,33 @@ func NewTableFromRecords(schema *arrow.Schema, recs []arrow.Record) *simpleTable
 	return NewTable(schema, cols, -1)
 }
 
-func (tbl *simpleTable) Schema() *arrow.Schema      { return tbl.schema }
+func (tbl *simpleTable) Schema() *arrow.Schema { return tbl.schema }
+
+// AddColumn returns a new table with column inserted at position i and field
+// inserted at the same position in the schema; the receiver is not modified.
+// It returns an error if the column length differs from the table's row count,
+// if the field and column data types are not equal, or if Schema.AddField rejects i.
+func (tbl *simpleTable) AddColumn(i int, field arrow.Field, column arrow.Column) (arrow.Table, error) {
+	if int64(column.Len()) != tbl.rows {
+		return nil, fmt.Errorf("arrow/array: column length mismatch: %d != %d", column.Len(), tbl.rows)
+	}
+	// Compare types structurally: equal types that were built separately are
+	// distinct interface values, so a plain != would reject them.
+	if !arrow.TypeEqual(field.Type, column.DataType()) {
+		return nil, fmt.Errorf("arrow/array: column type mismatch: %v != %v", field.Type, column.DataType())
+	}
+	newSchema, err := tbl.schema.AddField(i, field)
+	if err != nil {
+		return nil, err
+	}
+	cols := make([]arrow.Column, len(tbl.cols)+1)
+	copy(cols[:i], tbl.cols[:i])
+	cols[i] = column
+	copy(cols[i+1:], tbl.cols[i:])
+	newTable := NewTable(newSchema, cols, tbl.rows)
+	return newTable, nil
+}
+
 func (tbl *simpleTable) NumRows() int64             { return tbl.rows }
 func (tbl *simpleTable) NumCols() int64             { return int64(len(tbl.cols)) }
 func (tbl *simpleTable) Column(i int) *arrow.Column { return &tbl.cols[i] }
diff --git a/go/arrow/array/table_test.go b/go/arrow/array/table_test.go
index 67cab2066e297..752d1c424a4d7 100644
--- a/go/arrow/array/table_test.go
+++ b/go/arrow/array/table_test.go
@@ -404,6 +404,12 @@ func TestTable(t *testing.T) {
 	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
 	defer mem.AssertSize(t, 0)
 
+	preSchema := arrow.NewSchema(
+		[]arrow.Field{
+			{Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32},
+		},
+		nil,
+	)
 	schema := arrow.NewSchema(
 		[]arrow.Field{
 			{Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32},
@@ -469,8 +475,18 @@ func TestTable(t *testing.T) {
 
 	slices := [][]arrow.Array{col1.Data().Chunks(), col2.Data().Chunks()}
 
-	tbl := array.NewTable(schema, cols, -1)
+	preTbl := array.NewTable(preSchema, []arrow.Column{*col1}, -1)
+	defer preTbl.Release()
+	tbl, err := preTbl.AddColumn(
+		1,
+		arrow.Field{Name: "f2-f64", Type: arrow.PrimitiveTypes.Float64},
+		*col2,
+	)
+	if err != nil {
+		t.Fatalf("could not add column: %+v", err)
+	}
+	// Deferred only after the error check: tbl is nil when AddColumn fails.
 	defer tbl.Release()
 
 	tbl2 := array.NewTableFromSlice(schema, slices)
 	defer tbl2.Release()
diff --git a/go/arrow/schema.go b/go/arrow/schema.go
index 87bfe2b47a450..c10530ba822e1 100644
--- a/go/arrow/schema.go
+++ b/go/arrow/schema.go
@@ -232,6 +232,21 @@ func (sc *Schema) Equal(o *Schema) bool {
 	return true
 }
 
+// AddField returns a new schema with field inserted at index i, built from a
+// copy of the existing fields. It returns an error if i is negative or greater
+// than the current number of fields.
+func (s *Schema) AddField(i int, field Field) (*Schema, error) {
+	if i < 0 || i > len(s.fields) {
+		return nil, fmt.Errorf("arrow: invalid field index %d", i)
+	}
+
+	fields := make([]Field, len(s.fields)+1)
+	copy(fields[:i], s.fields[:i])
+	fields[i] = field
+	copy(fields[i+1:], s.fields[i:])
+	return NewSchema(fields, &s.meta), nil
+}
+
 func (s *Schema) String() string {
 	o := new(strings.Builder)
 	fmt.Fprintf(o, "schema:\n fields: %d\n", len(s.Fields()))
diff --git a/go/arrow/schema_test.go b/go/arrow/schema_test.go
index 1ef5c6432b7f4..201353d128b67 100644
--- a/go/arrow/schema_test.go
+++ b/go/arrow/schema_test.go
@@ -313,6 +313,30 @@ func TestSchema(t *testing.T) {
 	}
 }
 
+func TestSchemaAddField(t *testing.T) {
+	s := NewSchema([]Field{
+		{Name: "f1", Type: PrimitiveTypes.Int32},
+		{Name: "f2", Type: PrimitiveTypes.Int64},
+	}, nil)
+
+	_, err := s.AddField(3, Field{Name: "f3", Type: PrimitiveTypes.Int32})
+	if err == nil {
+		t.Fatalf("expected an error")
+	}
+
+	s, err = s.AddField(2, Field{Name: "f3", Type: PrimitiveTypes.Int32})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if got, want := len(s.Fields()), 3; got != want {
+		t.Fatalf("invalid number of fields. got=%d, want=%d", got, want)
+	}
+	got, want := s.Field(2), Field{Name: "f3", Type: PrimitiveTypes.Int32}
+	if !got.Equal(want) {
+		t.Fatalf("invalid field: got=%#v, want=%#v", got, want)
+	}
+}
+
 func TestSchemaEqual(t *testing.T) {
 	fields := []Field{
 		{Name: "f1", Type: PrimitiveTypes.Int32},
diff --git a/go/arrow/table.go b/go/arrow/table.go
index 0d20d955ab796..afddf277f0194 100644
--- a/go/arrow/table.go
+++ b/go/arrow/table.go
@@ -33,6 +33,10 @@ type Table interface {
 	NumCols() int64
 	Column(i int) *Column
 
+	// AddColumn adds a new column to the table and a corresponding field (of the same type)
+	// to its schema, at the specified position. Returns the new table with updated columns and schema.
+	AddColumn(pos int, f Field, c Column) (Table, error)
+
 	Retain()
 	Release()
 }