Skip to content

Commit

Permalink
fix database timezone issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sijms committed Apr 25, 2024
1 parent 0538163 commit bbeff0e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 24 deletions.
2 changes: 1 addition & 1 deletion v2/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ func (stmt *Stmt) getExeOption() int {
if !stmt.parse && !stmt.execute {
op |= 0x40
}
if len(stmt.Pars) > 0 {
if len(stmt.Pars) > 0 && !stmt.define {
op |= 0x8
if stmt.stmtType == PLSQL || (stmt._hasReturnClause && !stmt.reSendParDef) {
op |= 0x400
Expand Down
2 changes: 1 addition & 1 deletion v2/configurations/connect_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func ParseConfig(dsn string) (*ConnectionConfig, error) {
} else if tempVal == "POST" || tempVal == "STREAM" {
config.Lob = STREAM
} else {
return nil, errors.New("LOB FETCH value should be: PRE(default) or POST or STREAM")
return nil, errors.New("LOB FETCH value should be either INLINE/PRE (default) or STREAM/POST")
}
case "LANGUAGE":
config.Language = val[0]
Expand Down
36 changes: 27 additions & 9 deletions v2/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ type Connection struct {
date int
timestamp int
}
bad bool
dbTimeLoc *time.Location
bad bool
dbTimeZone *time.Location
}

type OracleConnector struct {
Expand Down Expand Up @@ -490,17 +490,24 @@ func (conn *Connection) OpenWithContext(ctx context.Context) error {
return err
}
}
conn.getDBTimeZone()
return nil
}

func (conn *Connection) getDBTimeZone() {
var current time.Time
err := conn.QueryRowContext(context.Background(), "SELECT SYSTIMESTAMP FROM DUAL", nil).Scan(&current)
func (conn *Connection) getDBTimeZone() error {
var result string
err := conn.QueryRowContext(context.Background(), "SELECT DBTIMEZONE FROM DUAL", nil).Scan(&result)
//var current time.Time
//err := conn.QueryRowContext(context.Background(), "SELECT SYSTIMESTAMP FROM DUAL", nil).Scan(&current)
if err != nil {
conn.dbTimeLoc = time.UTC
return err
}
var tzHours, tzMin int
_, err = fmt.Sscanf(result, "%03d:%02d", &tzHours, &tzMin)
if err != nil {
return err
}
conn.dbTimeLoc = current.Location()
conn.dbTimeZone = time.FixedZone(result, tzHours*60*60+tzMin*60)
return nil
}

// Begin a transaction
Expand Down Expand Up @@ -1261,10 +1268,21 @@ func (conn *Connection) dataTypeNegotiation() error {
if err != nil {
return err
}
err = conn.dataNego.read(conn.session)
conn.dbTimeZone, err = conn.dataNego.read(conn.session)
if err != nil {
return err
}
if conn.dbTimeZone == nil {
conn.tracer.Print("DB timezone not retrieved in data type negotiation")
conn.tracer.Print("try to query DB timezone")
err = conn.getDBTimeZone()
if err != nil {
conn.tracer.Print("error during get DB timezone: ", err)
conn.tracer.Print("set DB timezone to: UTC(+00:00)")
conn.dbTimeZone = time.UTC
}
}
conn.tracer.Print("DB timezone: ", conn.dbTimeZone)
conn.session.TTCVersion = conn.dataNego.CompileTimeCaps[7]
conn.session.UseBigScn = conn.tcpNego.ServerCompileTimeCaps[7] >= 8
if conn.tcpNego.ServerCompileTimeCaps[7] < conn.session.TTCVersion {
Expand Down
29 changes: 20 additions & 9 deletions v2/data_type_nego.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type DataTypeNego struct {
DataTypeRepFor1200 int16
CompileTimeCaps []byte
RuntimeCap []byte
DBTimeZone []byte
b32kTypeSupported bool
supportSessionStateOps bool
serverTZVersion int
Expand Down Expand Up @@ -462,19 +461,31 @@ func buildTypeNego(nego *TCPNego, session *network.Session) *DataTypeNego {
}
return &result
}
func (nego *DataTypeNego) read(session *network.Session) error {
msg, err := session.GetByte()
func (nego *DataTypeNego) read(session *network.Session) (zone *time.Location, err error) {
var msg uint8
msg, err = session.GetByte()
if err != nil {
return err
return
}
if msg != 2 {
return errors.New(fmt.Sprintf("message code error: received code %d and expected code is 2", msg))
err = errors.New(fmt.Sprintf("message code error: received code %d and expected code is 2", msg))
return
}
if nego.RuntimeCap[1] == 1 {
nego.DBTimeZone, err = session.GetBytes(11)
var tz_bytes []byte
tz_bytes, err = session.GetBytes(11)
if err != nil {
return err
return
}
if len(tz_bytes) < 11 {
err = errors.New("incorrect format for DBTimeZone")
return
}
tzHours := int(tz_bytes[4]) - 60
tzMin := int(tz_bytes[5]) - 60
tzSec := int(tz_bytes[6]) - 60
zone = time.FixedZone(fmt.Sprintf("%+03d:%02d", tzHours, tzMin),
tzHours*60*60+tzMin*60+tzSec)
if nego.CompileTimeCaps[37]&2 == 2 {
nego.serverTZVersion, _ = session.GetInt(4, false, true)
}
Expand Down Expand Up @@ -502,8 +513,8 @@ func (nego *DataTypeNego) read(session *network.Session) error {
}
//fmt.Println("server timezone version: ", nego.serverTZVersion)
//fmt.Println("client timezone version: ", nego.clientTZVersion)
//fmt.Println("server timezone: ", nego.DBTimeZone)
return nil
//fmt.Println("server timezone: ", nego.dbTimeZone)
return
}
func (nego *DataTypeNego) write(session *network.Session) error {
session.ResetBuffer()
Expand Down
8 changes: 4 additions & 4 deletions v2/parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1135,9 +1135,9 @@ func (par *ParameterInfo) decodePrimValue(conn *Connection, temporaryLobs *[][]b
if err != nil {
return err
}
if conn.dbTimeLoc != nil && conn.dbTimeLoc != time.UTC {
if conn.dbTimeZone != time.UTC {
par.oPrimValue = time.Date(tempTime.Year(), tempTime.Month(), tempTime.Day(),
tempTime.Hour(), tempTime.Minute(), tempTime.Second(), tempTime.Nanosecond(), conn.dbTimeLoc)
tempTime.Hour(), tempTime.Minute(), tempTime.Second(), tempTime.Nanosecond(), conn.dbTimeZone)
} else {
par.oPrimValue = tempTime
}
Expand All @@ -1153,8 +1153,8 @@ func (par *ParameterInfo) decodePrimValue(conn *Connection, temporaryLobs *[][]b
return err
}
par.oPrimValue = tempTime
if conn.dbTimeLoc != nil && conn.dbTimeLoc != time.UTC {
par.oPrimValue = tempTime.In(conn.dbTimeLoc)
if conn.dbTimeZone != time.UTC {
par.oPrimValue = tempTime.In(conn.dbTimeZone)
}
//case TimeStampDTY, TimeStampeLTZ, TimeStampLTZ_DTY, TIMESTAMPTZ, TimeStampTZ_DTY:
// fallthrough
Expand Down

0 comments on commit bbeff0e

Please sign in to comment.