From 6c9bd7da4582c54adad27989e2b24861fde73b49 Mon Sep 17 00:00:00 2001
From: pm47
Date: Fri, 21 May 2021 17:34:57 +0200
Subject: [PATCH] Use schemas in Postgres

Instead of having a flat hierarchy under the default `public` schema,
we now use roughly one schema per database type. The new hierarchy is:
- `local`
  - `channels`
  - `htlc_infos`
  - `pending_settlement_commands`
  - `peers`
- `network`
  - `nodes`
  - `public_channels`
  - `pruned_channels`
- `payments`
  - `received`
  - `sent`
- `audit`
- `public`
  - `lease`
  - `versions`

Note in particular the change in naming for local channels vs. external
channels:
- `local_channels` -> `local.channels`
- `channels` -> `network.public_channels`

The two internal tables `lease` and `versions` stay in the `public`
schema because we have no meta way of migrating them. A sketch at the
end of this patch shows one way to sanity-check the resulting layout.
---
 .../scala/fr/acinq/eclair/db/Databases.scala  |   8 +-
 .../fr/acinq/eclair/db/pg/PgAuditDb.scala     |  76 ++++++++-----
 .../fr/acinq/eclair/db/pg/PgChannelsDb.scala  |  52 +++++----
 .../fr/acinq/eclair/db/pg/PgNetworkDb.scala   | 100 +++++++++++-------
 .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala  |  64 ++++++-----
 .../fr/acinq/eclair/db/pg/PgPeersDb.scala     |  26 +++--
 .../eclair/db/pg/PgPendingCommandsDb.scala    |  22 ++--
 .../scala/fr/acinq/eclair/TestDatabases.scala |   3 +-
 .../fr/acinq/eclair/db/AuditDbSpec.scala      |  17 ++-
 .../fr/acinq/eclair/db/ChannelsDbSpec.scala   |  14 +--
 .../fr/acinq/eclair/db/NetworkDbSpec.scala    |  12 +--
 .../eclair/db/PendingCommandsDbSpec.scala     |   4 +-
 .../fr/acinq/eclair/db/PgUtilsSpec.scala      |  13 +++
 13 files changed, 259 insertions(+), 152 deletions(-)

diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala
index 29b26ca90d..9dc8b40122 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala
@@ -127,8 +127,12 @@ object Databases extends Logging {
       readOnlyUser_opt.foreach { readOnlyUser =>
         PgUtils.inTransaction { connection =>
           using(connection.createStatement()) { statement =>
-            logger.info(s"granting read-only access to user=$readOnlyUser")
-            statement.executeUpdate(s"GRANT SELECT ON ALL TABLES IN SCHEMA public TO $readOnlyUser")
+            val schemas = "public" :: "audit" :: "local" :: "network" :: "payments" :: Nil
+            schemas.foreach { schema =>
+              logger.info(s"granting read-only access to user=$readOnlyUser schema=$schema")
+              statement.executeUpdate(s"GRANT USAGE ON SCHEMA $schema TO $readOnlyUser")
+              statement.executeUpdate(s"GRANT SELECT ON ALL TABLES IN SCHEMA $schema TO $readOnlyUser")
+            }
           }
         }
       }
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
index 44852abdb8..0a7c29a4ce 100644
--- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
@@ -40,7 +40,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
   import ExtendedResultSet._
 
   val DB_NAME = "audit"
-  val CURRENT_VERSION = 6
+  val CURRENT_VERSION = 7
 
   case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long)
 
@@ -62,32 +62,50 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
       statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'")
     }
 
+    def 
migration67(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA audit") + statement.executeUpdate("ALTER TABLE sent SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE received SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE relayed SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE relayed_trampoline SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE network_fees SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE channel_events SET SCHEMA audit") + statement.executeUpdate("ALTER TABLE channel_errors SET SCHEMA audit") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE SCHEMA audit") - statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX relayed_payment_hash_idx ON relayed(payment_hash)") - statement.executeUpdate("CREATE INDEX relayed_trampoline_timestamp_idx ON relayed_trampoline(timestamp)") - statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)") - statement.executeUpdate("CREATE INDEX network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX channel_errors_timestamp_idx ON channel_errors(timestamp)") + statement.executeUpdate("CREATE TABLE audit.sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, 
to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + + statement.executeUpdate("CREATE TABLE audit.channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON audit.sent(timestamp)") + statement.executeUpdate("CREATE INDEX received_timestamp_idx ON audit.received(timestamp)") + statement.executeUpdate("CREATE INDEX relayed_timestamp_idx ON audit.relayed(timestamp)") + statement.executeUpdate("CREATE INDEX relayed_payment_hash_idx ON audit.relayed(payment_hash)") + statement.executeUpdate("CREATE INDEX relayed_trampoline_timestamp_idx ON audit.relayed_trampoline(timestamp)") + statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON audit.relayed_trampoline(payment_hash)") + statement.executeUpdate("CREATE INDEX network_fees_timestamp_idx ON audit.network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX channel_events_timestamp_idx ON audit.channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX channel_errors_timestamp_idx ON audit.channel_errors(timestamp)") case Some(v@4) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration45(statement) migration56(statement) + migration67(statement) case Some(v@5) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration56(statement) + migration67(statement) + case Some(v@6) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration67(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -97,7 +115,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: ChannelEvent): Unit = withMetrics("audit/add-channel-lifecycle", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO channel_events VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.channel_events VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, e.channelId.toHex) statement.setString(2, e.remoteNodeId.value.toHex) statement.setLong(3, 
e.capacity.toLong) @@ -112,7 +130,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: PaymentSent): Unit = withMetrics("audit/add-payment-sent", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.sent VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => e.parts.foreach(p => { statement.setLong(1, p.amount.toLong) statement.setLong(2, p.feesPaid.toLong) @@ -133,7 +151,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: PaymentReceived): Unit = withMetrics("audit/add-payment-received", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.received VALUES (?, ?, ?, ?)")) { statement => e.parts.foreach(p => { statement.setLong(1, p.amount.toLong) statement.setString(2, e.paymentHash.toHex) @@ -153,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { // non-trampoline relayed payments have one input and one output Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", ts), RelayedPart(toChannelId, amountOut, "OUT", "channel", ts)) case TrampolinePaymentRelayed(_, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, ts) => - using(pg.prepareStatement("INSERT INTO relayed_trampoline VALUES (?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.relayed_trampoline VALUES (?, ?, ?, ?)")) { statement => statement.setString(1, e.paymentHash.toHex) statement.setLong(2, nextTrampolineAmount.toLong) statement.setString(3, nextTrampolineNodeId.value.toHex) @@ -164,7 +182,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", ts)) ++ outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", ts)) } for (p <- payments) { - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, e.paymentHash.toHex) statement.setLong(2, p.amount.toLong) statement.setString(3, p.channelId.toHex) @@ -179,7 +197,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: NetworkFeePaid): Unit = withMetrics("audit/add-network-fee", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO network_fees VALUES (?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.network_fees VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, e.channelId.toHex) statement.setString(2, e.remoteNodeId.value.toHex) statement.setString(3, e.tx.txid.toHex) @@ -193,7 +211,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: ChannelErrorOccurred): Unit = withMetrics("audit/add-channel-error", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement => val (errorName, errorMessage) = e.error match { case LocalError(t) => (t.getClass.getSimpleName, t.getMessage) case RemoteError(error) => ("remote", error.toAscii) @@ -211,7 +229,7 @@ 
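// ---------------------------------------------------------------------------
// (Editor's illustrative sketch, not part of the patch.) Every Db class
// touched by this change follows the version-dispatch pattern shown above:
// read the stored version for the logical db, chain whichever migration
// steps are still missing, then stamp CURRENT_VERSION. The helpers below are
// simplified stand-ins for PgUtils.getVersion/setVersion, not eclair's
// actual API.
import java.sql.{Connection, Statement}

object MigrationPatternSketch {
  val DB_NAME = "audit"
  val CURRENT_VERSION = 7

  // Simplified stand-in: the real helper lives in PgUtils and uses the
  // public.versions table that this patch deliberately leaves in place.
  def getVersion(statement: Statement, dbName: String): Option[Int] = {
    statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
    val rs = statement.executeQuery(s"SELECT version FROM versions WHERE db_name = '$dbName'")
    if (rs.next()) Some(rs.getInt("version")) else None
  }

  def setVersion(statement: Statement, dbName: String, version: Int): Unit = {
    statement.executeUpdate(s"INSERT INTO versions VALUES ('$dbName', $version) ON CONFLICT (db_name) DO UPDATE SET version = EXCLUDED.version")
  }

  def migration67(statement: Statement): Unit = {
    statement.executeUpdate("CREATE SCHEMA audit")
    statement.executeUpdate("ALTER TABLE sent SET SCHEMA audit") // one SET SCHEMA per table, as in the hunk above
  }

  def init(connection: Connection): Unit = {
    val statement = connection.createStatement()
    getVersion(statement, DB_NAME) match {
      case None => // fresh install: create everything directly in the target schema
        statement.executeUpdate("CREATE SCHEMA audit")
      case Some(6) => migration67(statement) // older versions chain all remaining steps
      case Some(CURRENT_VERSION) => () // up to date, nothing to do
      case Some(v) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$v")
    }
    setVersion(statement, DB_NAME, CURRENT_VERSION)
  }
}
// Because ALTER TABLE ... SET SCHEMA is a metadata-only catalog change in
// Postgres, the v6 -> v7 move is cheap even for large audit tables: rows,
// indexes and constraints all follow the table into the new schema.
// ---------------------------------------------------------------------------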
class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listSent(from: Long, to: Long): Seq[PaymentSent] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement => + using(pg.prepareStatement("SELECT * FROM audit.sent WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery() @@ -241,7 +259,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement => + using(pg.prepareStatement("SELECT * FROM audit.received WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery() @@ -262,7 +280,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => - val trampolineByHash = using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => + val trampolineByHash = using(pg.prepareStatement("SELECT * FROM audit.relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery() @@ -273,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { trampolineByHash + (paymentHash -> (amount, nodeId)) } } - val relayedByHash = using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement => + val relayedByHash = using(pg.prepareStatement("SELECT * FROM audit.relayed WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery() @@ -308,7 +326,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + using(pg.prepareStatement("SELECT * FROM audit.network_fees WHERE timestamp BETWEEN ? and ? 
ORDER BY timestamp")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) statement.executeQuery().map { rs => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index 852baee843..0b11c48f7f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -41,7 +41,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit import lock._ val DB_NAME = "channels" - val CURRENT_VERSION = 5 + val CURRENT_VERSION = 6 inTransaction { pg => using(pg.createStatement()) { statement => @@ -66,32 +66,47 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit def migration45(statement: Statement): Unit = { statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN json JSONB") - resetJsonColumns(pg) + resetJsonColumns(pg, oldTableName = true) statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN json SET NOT NULL") statement.executeUpdate("CREATE INDEX local_channels_type_idx ON local_channels ((json->>'type'))") statement.executeUpdate("CREATE INDEX local_channels_remote_node_id_idx ON local_channels ((json->'commitments'->'remoteParams'->>'nodeId'))") } + def migration56(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") + statement.executeUpdate("ALTER TABLE local_channels SET SCHEMA local") + statement.executeUpdate("ALTER TABLE local.local_channels RENAME TO channels") + statement.executeUpdate("ALTER TABLE htlc_infos SET SCHEMA local") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, json JSONB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)") - statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") - statement.executeUpdate("CREATE INDEX local_channels_type_idx ON local_channels ((json->>'type'))") - statement.executeUpdate("CREATE INDEX local_channels_remote_node_id_idx ON local_channels ((json->'commitments'->'remoteParams'->>'nodeId'))") - statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + statement.executeUpdate("CREATE TABLE local.channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, json JSONB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE local.htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES 
local.channels(channel_id))") + + statement.executeUpdate("CREATE INDEX local_channels_type_idx ON local.channels ((json->>'type'))") + statement.executeUpdate("CREATE INDEX local_channels_remote_node_id_idx ON local.channels ((json->'commitments'->'remoteParams'->>'nodeId'))") + statement.executeUpdate("CREATE INDEX htlc_infos_idx ON local.htlc_infos(channel_id, commitment_number)") case Some(v@2) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration23(statement) migration34(statement) migration45(statement) + migration56(statement) case Some(v@3) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration34(statement) migration45(statement) + migration56(statement) case Some(v@4) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration45(statement) + migration56(statement) + case Some(v@5) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration56(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -100,10 +115,11 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit } /** Sometimes we may want to do a full reset when we update the json format */ - def resetJsonColumns(connection: Connection): Unit = { + def resetJsonColumns(connection: Connection, oldTableName: Boolean = false): Unit = { + val table = if (oldTableName) "local_channels" else "local.channels" migrateTable(connection, connection, - "local_channels", - "UPDATE local_channels SET json=?::JSONB WHERE channel_id=?", + table, + s"UPDATE $table SET json=?::JSONB WHERE channel_id=?", (rs, statement) => { val state = stateDataCodec.decode(BitVector(rs.getBytes("data"))).require.value val json = serialization.writePretty(state) @@ -118,7 +134,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit val data = stateDataCodec.encode(state).require.toByteArray using(pg.prepareStatement( """ - | INSERT INTO local_channels (channel_id, data, json, is_closed) + | INSERT INTO local.channels (channel_id, data, json, is_closed) | VALUES (?, ?, ?::JSONB, FALSE) | ON CONFLICT (channel_id) | DO UPDATE SET data = EXCLUDED.data, json = EXCLUDED.json ; @@ -136,7 +152,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit */ private def updateChannelMetaTimestampColumn(channelId: ByteVector32, columnName: String): Unit = { inTransaction(IsolationLevel.TRANSACTION_READ_UNCOMMITTED) { pg => - using(pg.prepareStatement(s"UPDATE local_channels SET $columnName=? WHERE channel_id=?")) { statement => + using(pg.prepareStatement(s"UPDATE local.channels SET $columnName=? 
WHERE channel_id=?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.now())) statement.setString(2, channelId.toHex) statement.executeUpdate() @@ -158,17 +174,17 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def removeChannel(channelId: ByteVector32): Unit = withMetrics("channels/remove-channel", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM local.pending_settlement_commands WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) statement.executeUpdate() } - using(pg.prepareStatement("DELETE FROM htlc_infos WHERE channel_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM local.htlc_infos WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) statement.executeUpdate() } - using(pg.prepareStatement("UPDATE local_channels SET is_closed=TRUE WHERE channel_id=?")) { statement => + using(pg.prepareStatement("UPDATE local.channels SET is_closed=TRUE WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) statement.executeUpdate() } @@ -178,7 +194,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def listLocalChannels(): Seq[HasCommitments] = withMetrics("channels/list-local-channels", DbBackends.Postgres) { withLock { pg => using(pg.createStatement) { statement => - statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=FALSE") + statement.executeQuery("SELECT data FROM local.channels WHERE is_closed=FALSE") .mapCodec(stateDataCodec).toSeq } } @@ -186,7 +202,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def addHtlcInfo(channelId: ByteVector32, commitmentNumber: Long, paymentHash: ByteVector32, cltvExpiry: CltvExpiry): Unit = withMetrics("channels/add-htlc-info", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO htlc_infos VALUES (?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO local.htlc_infos VALUES (?, ?, ?, ?)")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, commitmentNumber) statement.setString(3, paymentHash.toHex) @@ -198,7 +214,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def listHtlcInfos(channelId: ByteVector32, commitmentNumber: Long): Seq[(ByteVector32, CltvExpiry)] = withMetrics("channels/list-htlc-infos", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => + using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM local.htlc_infos WHERE channel_id=? 
AND commitment_number=?")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, commitmentNumber) statement.executeQuery diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala index 92b81c8b9b..0953f6c2fe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala @@ -38,7 +38,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { import fr.acinq.eclair.json.JsonSerializers.{formats, serialization} val DB_NAME = "network" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 inTransaction { pg => using(pg.createStatement()) { statement => @@ -48,19 +48,33 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_announcement_json JSONB") statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_1_json JSONB") statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_2_json JSONB") - resetJsonColumns(pg) + resetJsonColumns(pg, oldTableName = true) statement.executeUpdate("ALTER TABLE nodes ALTER COLUMN json SET NOT NULL") statement.executeUpdate("ALTER TABLE channels ALTER COLUMN channel_announcement_json SET NOT NULL") } + def migration34(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA network") + statement.executeUpdate("ALTER TABLE nodes SET SCHEMA network") + statement.executeUpdate("ALTER TABLE channels RENAME TO public_channels") + statement.executeUpdate("ALTER TABLE public_channels SET SCHEMA network") + statement.executeUpdate("ALTER TABLE pruned RENAME TO pruned_channels") + statement.executeUpdate("ALTER TABLE pruned_channels SET SCHEMA network") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE nodes (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, json JSONB NOT NULL)") - statement.executeUpdate("CREATE TABLE channels (short_channel_id BIGINT NOT NULL PRIMARY KEY, txid TEXT NOT NULL, channel_announcement BYTEA NOT NULL, capacity_sat BIGINT NOT NULL, channel_update_1 BYTEA NULL, channel_update_2 BYTEA NULL, channel_announcement_json JSONB NOT NULL, channel_update_1_json JSONB NULL, channel_update_2_json JSONB NULL)") - statement.executeUpdate("CREATE TABLE pruned (short_channel_id BIGINT NOT NULL PRIMARY KEY)") + statement.executeUpdate("CREATE SCHEMA network") + statement.executeUpdate("CREATE TABLE network.nodes (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, json JSONB NOT NULL)") + statement.executeUpdate("CREATE TABLE network.public_channels (short_channel_id BIGINT NOT NULL PRIMARY KEY, txid TEXT NOT NULL, channel_announcement BYTEA NOT NULL, capacity_sat BIGINT NOT NULL, channel_update_1 BYTEA NULL, channel_update_2 BYTEA NULL, channel_announcement_json JSONB NOT NULL, channel_update_1_json JSONB NULL, channel_update_2_json JSONB NULL)") + statement.executeUpdate("CREATE TABLE network.pruned_channels (short_channel_id BIGINT NOT NULL PRIMARY KEY)") case Some(v@2) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration23(statement) + migration34(statement) + case Some(v@3) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration34(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new 
RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -69,10 +83,12 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { } /** Sometimes we may want to do a full reset when we update the json format */ - def resetJsonColumns(connection: Connection): Unit = { + def resetJsonColumns(connection: Connection, oldTableName: Boolean = false): Unit = { + val nodesTable = if (oldTableName) "nodes" else "network.nodes" + val channelsTable = if (oldTableName) "channels" else "network.public_channels" migrateTable(connection, connection, - "nodes", - "UPDATE nodes SET json=?::JSON WHERE node_id=?", + nodesTable, + s"UPDATE $nodesTable SET json=?::JSON WHERE node_id=?", (rs, statement) => { val node = nodeAnnouncementCodec.decode(BitVector(rs.getBytes("data"))).require.value val json = serialization.writePretty(node) @@ -81,8 +97,8 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { } )(logger) migrateTable(connection, connection, - "channels", - "UPDATE channels SET channel_announcement_json=?::JSON, channel_update_1_json=?::JSON, channel_update_2_json=?::JSON WHERE short_channel_id=?", + channelsTable, + s"UPDATE $channelsTable SET channel_announcement_json=?::JSON, channel_update_1_json=?::JSON, channel_update_2_json=?::JSON WHERE short_channel_id=?", (rs, statement) => { val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) @@ -100,29 +116,31 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def addNode(n: NodeAnnouncement): Unit = withMetrics("network/add-node", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO nodes (node_id, data, json) VALUES (?, ?, ?::JSONB) ON CONFLICT DO NOTHING")) { statement => - statement.setString(1, n.nodeId.value.toHex) - statement.setBytes(2, nodeAnnouncementCodec.encode(n).require.toByteArray) - statement.setString(3, serialization.writePretty(n)) - statement.executeUpdate() + using(pg.prepareStatement("INSERT INTO network.nodes (node_id, data, json) VALUES (?, ?, ?::JSONB) ON CONFLICT DO NOTHING")) { + statement => + statement.setString(1, n.nodeId.value.toHex) + statement.setBytes(2, nodeAnnouncementCodec.encode(n).require.toByteArray) + statement.setString(3, serialization.writePretty(n)) + statement.executeUpdate() } } } override def updateNode(n: NodeAnnouncement): Unit = withMetrics("network/update-node", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("UPDATE nodes SET data=?, json=?::JSONB WHERE node_id=?")) { statement => - statement.setBytes(1, nodeAnnouncementCodec.encode(n).require.toByteArray) - statement.setString(2, serialization.writePretty(n)) - statement.setString(3, n.nodeId.value.toHex) - statement.executeUpdate() + using(pg.prepareStatement("UPDATE network.nodes SET data=?, json=?::JSONB WHERE node_id=?")) { + statement => + statement.setBytes(1, nodeAnnouncementCodec.encode(n).require.toByteArray) + statement.setString(2, serialization.writePretty(n)) + statement.setString(3, n.nodeId.value.toHex) + statement.executeUpdate() } } } override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = withMetrics("network/get-node", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement => + using(pg.prepareStatement("SELECT 
data FROM network.nodes WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) statement.executeQuery() .mapCodec(nodeAnnouncementCodec) @@ -133,7 +151,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def removeNode(nodeId: Crypto.PublicKey): Unit = withMetrics("network/remove-node", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("DELETE FROM nodes WHERE node_id=?")) { + using(pg.prepareStatement("DELETE FROM network.nodes WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) statement.executeUpdate() @@ -144,7 +162,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def listNodes(): Seq[NodeAnnouncement] = withMetrics("network/list-nodes", DbBackends.Postgres) { inTransaction { pg => using(pg.createStatement()) { statement => - statement.executeQuery("SELECT data FROM nodes") + statement.executeQuery("SELECT data FROM network.nodes") .mapCodec(nodeAnnouncementCodec).toSeq } } @@ -152,13 +170,14 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = withMetrics("network/add-channel", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO channels(short_channel_id, txid, channel_announcement, capacity_sat, channel_announcement_json) VALUES (?, ?, ?, ?, ?::JSONB) ON CONFLICT DO NOTHING")) { statement => - statement.setLong(1, c.shortChannelId.toLong) - statement.setString(2, txid.toHex) - statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray) - statement.setLong(4, capacity.toLong) - statement.setString(5, serialization.writePretty(c)) - statement.executeUpdate() + using(pg.prepareStatement("INSERT INTO network.public_channels (short_channel_id, txid, channel_announcement, capacity_sat, channel_announcement_json) VALUES (?, ?, ?, ?, ?::JSONB) ON CONFLICT DO NOTHING")) { + statement => + statement.setLong(1, c.shortChannelId.toLong) + statement.setString(2, txid.toHex) + statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray) + statement.setLong(4, capacity.toLong) + statement.setString(5, serialization.writePretty(c)) + statement.executeUpdate() } } } @@ -166,11 +185,12 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def updateChannel(u: ChannelUpdate): Unit = withMetrics("network/update-channel", DbBackends.Postgres) { val column = if (u.isNode1) "channel_update_1" else "channel_update_2" inTransaction { pg => - using(pg.prepareStatement(s"UPDATE channels SET $column=?, ${column}_json=?::JSONB WHERE short_channel_id=?")) { statement => - statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray) - statement.setString(2, serialization.writePretty(u)) - statement.setLong(3, u.shortChannelId.toLong) - statement.executeUpdate() + using(pg.prepareStatement(s"UPDATE network.public_channels SET $column=?, ${column}_json=?::JSONB WHERE short_channel_id=?")) { + statement => + statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray) + statement.setString(2, serialization.writePretty(u)) + statement.setLong(3, u.shortChannelId.toLong) + statement.executeUpdate() } } } @@ -178,7 +198,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = withMetrics("network/list-channels", DbBackends.Postgres) { 
inTransaction { pg => using(pg.createStatement()) { statement => - statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") + statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM network.public_channels") .foldLeft(SortedMap.empty[ShortChannelId, PublicChannel]) { (m, rs) => val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value val txId = ByteVector32.fromValidHex(rs.getString("txid")) @@ -194,7 +214,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = withMetrics("network/remove-channels", DbBackends.Postgres) { val batchSize = 100 inTransaction { pg => - using(pg.prepareStatement(s"DELETE FROM channels WHERE short_channel_id IN (${ + using(pg.prepareStatement(s"DELETE FROM network.public_channels WHERE short_channel_id IN (${ List.fill(batchSize)("?").mkString(",") })")) { statement => @@ -214,7 +234,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = withMetrics("network/add-to-pruned", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO pruned VALUES (?) ON CONFLICT DO NOTHING")) { + using(pg.prepareStatement("INSERT INTO network.pruned_channels VALUES (?) ON CONFLICT DO NOTHING")) { statement => shortChannelIds.foreach(shortChannelId => { statement.setLong(1, shortChannelId.toLong) @@ -227,7 +247,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def removeFromPruned(shortChannelId: ShortChannelId): Unit = withMetrics("network/remove-from-pruned", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement(s"DELETE FROM pruned WHERE short_channel_id=?")) { + using(pg.prepareStatement(s"DELETE FROM network.pruned_channels WHERE short_channel_id=?")) { statement => statement.setLong(1, shortChannelId.toLong) statement.executeUpdate() @@ -237,7 +257,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def isPruned(shortChannelId: ShortChannelId): Boolean = withMetrics("network/is-pruned", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("SELECT short_channel_id from pruned WHERE short_channel_id=?")) { statement => + using(pg.prepareStatement("SELECT short_channel_id from network.pruned_channels WHERE short_channel_id=?")) { statement => statement.setLong(1, shortChannelId.toLong) statement.executeQuery().nonEmpty } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index c8a44174a2..883a148acd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -30,7 +30,7 @@ import scodec.Attempt import scodec.bits.BitVector import scodec.codecs._ -import java.sql.ResultSet +import java.sql.{ResultSet, Statement} import java.util.UUID import javax.sql.DataSource import scala.concurrent.duration._ @@ -42,7 +42,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit import lock._ val DB_NAME = "payments" - val CURRENT_VERSION = 4 + val CURRENT_VERSION = 5 private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | 
CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary] private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte) @@ -53,15 +53,29 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit inTransaction { pg => using(pg.createStatement()) { statement => + + def migration45(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA payments") + statement.executeUpdate("ALTER TABLE received_payments RENAME TO received") + statement.executeUpdate("ALTER TABLE received SET SCHEMA payments") + statement.executeUpdate("ALTER TABLE sent_payments RENAME TO sent") + statement.executeUpdate("ALTER TABLE sent SET SCHEMA payments") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE received_payments (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)") - statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)") + statement.executeUpdate("CREATE SCHEMA payments") + + statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, payment_request TEXT NOT NULL, received_msat BIGINT, created_at BIGINT NOT NULL, expire_at BIGINT NOT NULL, received_at BIGINT)") + statement.executeUpdate("CREATE TABLE payments.sent (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, payment_route BYTEA, failures BYTEA, created_at BIGINT NOT NULL, completed_at BIGINT)") - statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON sent_payments(parent_id)") - statement.executeUpdate("CREATE INDEX sent_payment_hash_idx ON sent_payments(payment_hash)") - statement.executeUpdate("CREATE INDEX sent_created_idx ON sent_payments(created_at)") - statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") + statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON payments.sent(parent_id)") + statement.executeUpdate("CREATE INDEX sent_payment_hash_idx ON payments.sent(payment_hash)") + statement.executeUpdate("CREATE INDEX sent_created_idx ON payments.sent(created_at)") + statement.executeUpdate("CREATE INDEX received_created_idx ON payments.received(created_at)") + case Some(v@4) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration45(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -72,7 +86,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def addOutgoingPayment(sent: OutgoingPayment): Unit = 
withMetrics("payments/add-outgoing", DbBackends.Postgres) { require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})") withLock { pg => - using(pg.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, payment_type, amount_msat, recipient_amount_msat, recipient_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO payments.sent (id, parent_id, external_id, payment_hash, payment_type, amount_msat, recipient_amount_msat, recipient_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) statement.setString(2, sent.parentId.toString) statement.setString(3, sent.externalId.orNull) @@ -90,7 +104,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = withMetrics("payments/update-outgoing-sent", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + using(pg.prepareStatement("UPDATE payments.sent SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => paymentResult.parts.foreach(p => { statement.setLong(1, p.timestamp) statement.setString(2, paymentResult.paymentPreimage.toHex) @@ -106,7 +120,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = withMetrics("payments/update-outgoing-failed", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("UPDATE sent_payments SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + using(pg.prepareStatement("UPDATE payments.sent SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => statement.setLong(1, paymentResult.timestamp) statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) statement.setString(3, paymentResult.id.toString) @@ -165,7 +179,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = withMetrics("payments/get-outgoing", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.sent WHERE id = ?")) { statement => statement.setString(1, id.toString) statement.executeQuery().map(parseOutgoingPayment).headOption } @@ -174,7 +188,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-parent-id", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.sent WHERE parent_id = ? 
ORDER BY created_at")) { statement => statement.setString(1, parentId.toString) statement.executeQuery().map(parseOutgoingPayment).toSeq } @@ -183,7 +197,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-payment-hash", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.sent WHERE payment_hash = ? ORDER BY created_at")) { statement => statement.setString(1, paymentHash.toHex) statement.executeQuery().map(parseOutgoingPayment).toSeq } @@ -192,7 +206,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-timestamp", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.sent WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.executeQuery().map { rs => @@ -204,7 +218,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String): Unit = withMetrics("payments/add-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO payments.received (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, pr.paymentHash.toHex) statement.setString(2, preimage.toHex) statement.setString(3, paymentType) @@ -218,7 +232,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = withMetrics("payments/receive-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => + using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) 
WHERE payment_hash = ?")) { update => update.setLong(1, amount.toLong) update.setLong(2, receivedAt) update.setString(3, paymentHash.toHex) @@ -250,7 +264,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => statement.setString(1, paymentHash.toHex) statement.executeQuery().map(parseIncomingPayment).headOption } @@ -259,7 +273,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.received WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.executeQuery().map(parseIncomingPayment).toSeq @@ -269,7 +283,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.executeQuery().map(parseIncomingPayment).toSeq @@ -279,7 +293,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) @@ -290,7 +304,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => + using(pg.prepareStatement("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? 
ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) @@ -300,7 +314,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } override def listPaymentsOverview(limit: Int): Seq[PlainPayment] = withMetrics("payments/list-overview", DbBackends.Postgres) { - // This query is an UNION of the ``sent_payments`` and ``received_payments`` table + // This query is an UNION of the ``payments.sent`` and ``payments.received`` table // - missing fields set to NULL when needed. // - only retrieve incoming payments that did receive funds. // - outgoing payments are grouped by parent_id. @@ -321,7 +335,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit | received_at as completed_at, | expire_at, | NULL as order_trick - | FROM received_payments + | FROM payments.received | WHERE received_msat > 0 |UNION ALL | SELECT 'sent' as type, @@ -336,7 +350,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit | completed_at, | NULL as expire_at, | MAX(coalesce(completed_at, created_at)) as order_trick - | FROM sent_payments + | FROM payments.sent | GROUP BY parent_id,external_id,payment_hash,payment_preimage,payment_type,payment_request,created_at,completed_at |) q |ORDER BY coalesce(q.completed_at, q.created_at) DESC diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala index a26b1c93c5..035c443758 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala @@ -23,24 +23,36 @@ import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.PeersDb import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.wire.protocol._ +import grizzled.slf4j.Logging import scodec.bits.BitVector +import java.sql.Statement import javax.sql.DataSource -class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { +class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging { import PgUtils.ExtendedResultSet._ import PgUtils._ import lock._ val DB_NAME = "peers" - val CURRENT_VERSION = 1 + val CURRENT_VERSION = 2 inTransaction { pg => + + def migration12(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") + statement.executeUpdate("ALTER TABLE peers SET SCHEMA local") + } + using(pg.createStatement()) { statement => getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") + statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + case Some(v@1) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration12(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -53,7 +65,7 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray using(pg.prepareStatement( """ - | INSERT INTO peers (node_id, data) + | INSERT INTO local.peers (node_id, data) | VALUES (?, ?) 
| ON CONFLICT (node_id) | DO UPDATE SET data = EXCLUDED.data ; @@ -67,7 +79,7 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { override def removePeer(nodeId: Crypto.PublicKey): Unit = withMetrics("peers/remove", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("DELETE FROM peers WHERE node_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM local.peers WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) statement.executeUpdate() } @@ -76,7 +88,7 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { override def getPeer(nodeId: PublicKey): Option[NodeAddress] = withMetrics("peers/get", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement => + using(pg.prepareStatement("SELECT data FROM local.peers WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) statement.executeQuery() .mapCodec(CommonCodecs.nodeaddress) @@ -88,7 +100,7 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { override def listPeers(): Map[PublicKey, NodeAddress] = withMetrics("peers/list", DbBackends.Postgres) { withLock { pg => using(pg.createStatement()) { statement => - statement.executeQuery("SELECT node_id, data FROM peers") + statement.executeQuery("SELECT node_id, data FROM local.peers") .map { rs => val nodeid = PublicKey(rs.getByteVectorFromHex("node_id")) val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala index dee8347273..a807eacf15 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala @@ -36,7 +36,7 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending import lock._ val DB_NAME = "pending_relay" - val CURRENT_VERSION = 2 + val CURRENT_VERSION = 3 inTransaction { pg => using(pg.createStatement()) { statement => @@ -45,13 +45,23 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending statement.executeUpdate("ALTER TABLE pending_relay RENAME TO pending_settlement_commands") } + def migration23(statement: Statement): Unit = { + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") + statement.executeUpdate("ALTER TABLE pending_settlement_commands SET SCHEMA local") + } + getVersion(statement, DB_NAME) match { case None => + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") // note: should we use a foreign key to local_channels table here? 
- statement.executeUpdate("CREATE TABLE pending_settlement_commands (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + statement.executeUpdate("CREATE TABLE local.pending_settlement_commands (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))") case Some(v@1) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration12(statement) + migration23(statement) + case Some(v@2) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration23(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -61,7 +71,7 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending override def addSettlementCommand(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = withMetrics("pending-relay/add", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO pending_settlement_commands VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => + using(pg.prepareStatement("INSERT INTO local.pending_settlement_commands VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, cmd.id) statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray) @@ -72,7 +82,7 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending override def removeSettlementCommand(channelId: ByteVector32, htlcId: Long): Unit = withMetrics("pending-relay/remove", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=? AND htlc_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM local.pending_settlement_commands WHERE channel_id=? 
AND htlc_id=?")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, htlcId) statement.executeUpdate() @@ -82,7 +92,7 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending override def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT htlc_id, data FROM pending_settlement_commands WHERE channel_id=?")) { statement => + using(pg.prepareStatement("SELECT htlc_id, data FROM local.pending_settlement_commands WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) statement.executeQuery() .mapCodec(cmdCodec).toSeq @@ -92,7 +102,7 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = withMetrics("pending-relay/list", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT channel_id, data FROM pending_settlement_commands")) { statement => + using(pg.prepareStatement("SELECT channel_id, data FROM local.pending_settlement_commands")) { statement => statement.executeQuery() .map(rs => (rs.getByteVector32FromHex("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value)) .toSeq diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index 3cd5e07819..59b4418ac0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -7,6 +7,7 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using} import org.postgresql.jdbc.PgConnection +import org.scalatest.Assertions.convertToEqualizer import org.sqlite.SQLiteConnection import java.io.File @@ -91,7 +92,7 @@ object TestDatabases { val _ = dbs.db // check that db version was updated using(connection.createStatement()) { statement => - assert(getVersion(statement, dbName).contains(targetVersion)) + assert(getVersion(statement, dbName).contains(targetVersion), "unexpected version post-migration") } // post-migration checks postCheck(connection) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 8f61673ebd..31fb7bb4ec 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -386,7 +386,7 @@ class AuditDbSpec extends AnyFunSuite { ) } - test("migrate audit database v4 -> v5/v6") { + test("migrate audit database v4 -> v5/v7") { val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32(), randomBytes32(), randomBytes32(), 105) val relayed2 = TrampolinePaymentRelayed(randomBytes32(), Seq(PaymentRelayed.Part(300 msat, randomBytes32()), PaymentRelayed.Part(350 msat, randomBytes32())), Seq(PaymentRelayed.Part(600 msat, randomBytes32())), PlaceHolderPubKey, 0 msat, 110) @@ -458,17 +458,15 @@ class AuditDbSpec extends AnyFunSuite { } }, dbName = "audit", - targetVersion = 6, + targetVersion = 7, postCheck = connection => { val migratedDb = dbs.audit - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(6)) - } + assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) val 
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala
index 8f61673ebd..31fb7bb4ec 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala
@@ -386,7 +386,7 @@ class AuditDbSpec extends AnyFunSuite {
     )
   }
 
-  test("migrate audit database v4 -> v5/v6") {
+  test("migrate audit database v4 -> v5/v7") {
     val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32(), randomBytes32(), randomBytes32(), 105)
     val relayed2 = TrampolinePaymentRelayed(randomBytes32(), Seq(PaymentRelayed.Part(300 msat, randomBytes32()), PaymentRelayed.Part(350 msat, randomBytes32())), Seq(PaymentRelayed.Part(600 msat, randomBytes32())), PlaceHolderPubKey, 0 msat, 110)
@@ -458,17 +458,15 @@ class AuditDbSpec extends AnyFunSuite {
         }
       },
       dbName = "audit",
-      targetVersion = 6,
+      targetVersion = 7,
       postCheck = connection => {
         val migratedDb = dbs.audit
-        using(connection.createStatement()) { statement =>
-          assert(getVersion(statement, "audit").contains(6))
-        }
+
         assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2))
         val postMigrationDb = new PgAuditDb()(dbs.datasource)
         using(connection.createStatement()) { statement =>
-          assert(getVersion(statement, "audit").contains(6))
+          assert(getVersion(statement, "audit").contains(7))
         }
         val relayed3 = TrampolinePaymentRelayed(randomBytes32(), Seq(PaymentRelayed.Part(450 msat, randomBytes32()), PaymentRelayed.Part(500 msat, randomBytes32())), Seq(PaymentRelayed.Part(800 msat, randomBytes32())), randomKey().publicKey, 700 msat, 150)
         postMigrationDb.add(relayed3)
@@ -566,8 +564,9 @@ class AuditDbSpec extends AnyFunSuite {
       val db = dbs.audit
       val sqlite = dbs.connection
       val isPg = dbs.isInstanceOf[TestPgDatabases]
+      val table = if (isPg) "audit.relayed" else "relayed"
 
-      using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
+      using(sqlite.prepareStatement(s"INSERT INTO $table (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
         if (isPg) statement.setString(1, randomBytes32().toHex) else statement.setBytes(1, randomBytes32().toArray)
         statement.setLong(2, 42)
         if (isPg) statement.setString(3, randomBytes32().toHex) else statement.setBytes(3, randomBytes32().toArray)
@@ -577,7 +576,7 @@ class AuditDbSpec extends AnyFunSuite {
         statement.executeUpdate()
       }
 
-      using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
+      using(sqlite.prepareStatement(s"INSERT INTO $table (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
         if (isPg) statement.setString(1, randomBytes32().toHex) else statement.setBytes(1, randomBytes32().toArray)
         statement.setLong(2, 51)
         if (isPg) statement.setString(3, randomBytes32().toHex) else statement.setBytes(3, randomBytes32().toArray)
@@ -590,7 +589,7 @@ class AuditDbSpec extends AnyFunSuite {
       val paymentHash = randomBytes32()
       val channelId = randomBytes32()
 
-      using(sqlite.prepareStatement("INSERT INTO relayed (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
+      using(sqlite.prepareStatement(s"INSERT INTO $table (payment_hash, amount_msat, channel_id, direction, relay_type, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement =>
         if (isPg) statement.setString(1, paymentHash.toHex) else statement.setBytes(1, paymentHash.toArray)
         statement.setLong(2, 65)
         if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray)
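Since the relayed table is now audit.relayed on Postgres but still relayed on SQLite, the shared test above picks the table name per backend, just as it already picks hex TEXT vs raw BLOB column encodings. A compact sketch of that pattern (illustrative helpers, not the spec's API):

    import java.sql.PreparedStatement

    // Illustrative: backend-dependent naming and encoding in shared tests.
    def relayedTable(isPg: Boolean): String = if (isPg) "audit.relayed" else "relayed"

    def setHash(statement: PreparedStatement, index: Int, hash: Array[Byte], isPg: Boolean): Unit =
      if (isPg) statement.setString(index, hash.map("%02x".format(_)).mkString) // hex TEXT on Postgres
      else statement.setBytes(index, hash) // raw bytes on SQLite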
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala
index ee800bd772..bb589fd11a 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala
@@ -199,7 +199,7 @@ class ChannelsDbSpec extends AnyFunSuite {
     }
   }
 
-  test("migrate channel database v2 -> v3/v5") {
+  test("migrate channel database v2 -> v3/v6") {
     def postCheck(channelsDb: ChannelsDb): Unit = {
       assert(channelsDb.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
       for (testCase <- testCases.filterNot(_.isClosed)) {
@@ -242,7 +242,7 @@ class ChannelsDbSpec extends AnyFunSuite {
            }
          },
          dbName = "channels",
-          targetVersion = 5,
+          targetVersion = 6,
          postCheck = _ => postCheck(dbs.channels)
        )
      case dbs: TestSqliteDatabases =>
@@ -283,7 +283,7 @@ class ChannelsDbSpec extends AnyFunSuite {
     }
   }
 
-  test("migrate pg channel database v3->v5") {
+  test("migrate pg channel database v3->v6") {
    val dbs = TestPgDatabases()
 
    migrationCheck(
@@ -312,7 +312,7 @@ class ChannelsDbSpec extends AnyFunSuite {
        }
      },
      dbName = "channels",
-      targetVersion = 5,
+      targetVersion = 6,
      postCheck = connection => {
        assert(dbs.channels.listLocalChannels().size === testCases.filterNot(_.isClosed).size)
        testCases.foreach { testCase =>
@@ -331,10 +331,10 @@ class ChannelsDbSpec extends AnyFunSuite {
     val db = dbs.channels
     val channel = ChannelCodecsSpec.normal
     db.addOrUpdateChannel(channel)
-    dbs.connection.execSQLUpdate("UPDATE local_channels SET json='{}'")
+    dbs.connection.execSQLUpdate("UPDATE local.channels SET json='{}'")
     db.asInstanceOf[PgChannelsDb].resetJsonColumns(dbs.connection)
     assert({
-      val res = dbs.connection.execSQLQuery("SELECT * FROM local_channels")
+      val res = dbs.connection.execSQLQuery("SELECT * FROM local.channels")
       res.next()
       res.getString("json").length > 100
     })
@@ -387,7 +387,7 @@ object ChannelsDbSpec {
   }
 
   def getPgTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = {
-    using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement =>
+    using(connection.prepareStatement(s"SELECT $columnName FROM local.channels WHERE channel_id=?")) { statement =>
       statement.setString(1, channelId.toHex)
       val rs = statement.executeQuery()
       rs.next()
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala
index 25de7a3af3..fdc8c7c46a 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala
@@ -286,7 +286,7 @@ class NetworkDbSpec extends AnyFunSuite {
     )
   }
 
-  test("migration test 2->3 (postgres)") {
+  test("migration test 2->4 (postgres)") {
     val dbs = TestPgDatabases()
     migrationCheck(
       dbs = dbs,
@@ -317,7 +317,7 @@ class NetworkDbSpec extends AnyFunSuite {
         }
       },
       dbName = "network",
-      targetVersion = 3,
+      targetVersion = 4,
       postCheck = _ => {
         assert(dbs.network.listNodes().toSet === nodeTestCases.map(_.node).toSet)
         // NB: channel updates are not migrated
@@ -335,16 +335,16 @@ class NetworkDbSpec extends AnyFunSuite {
       t.update_1_opt.foreach(db.updateChannel)
       t.update_2_opt.foreach(db.updateChannel)
     }
-    dbs.connection.execSQLUpdate("UPDATE nodes SET json='{}'")
-    dbs.connection.execSQLUpdate("UPDATE channels SET channel_announcement_json='{}',channel_update_1_json=NULL,channel_update_2_json=NULL")
+    dbs.connection.execSQLUpdate("UPDATE network.nodes SET json='{}'")
+    dbs.connection.execSQLUpdate("UPDATE network.public_channels SET channel_announcement_json='{}',channel_update_1_json=NULL,channel_update_2_json=NULL")
     db.asInstanceOf[PgNetworkDb].resetJsonColumns(dbs.connection)
     assert({
-      val res = dbs.connection.execSQLQuery("SELECT * FROM nodes")
+      val res = dbs.connection.execSQLQuery("SELECT * FROM network.nodes")
       res.next()
       res.getString("json").length > 100
     })
     assert({
-      val res = dbs.connection.execSQLQuery("SELECT * FROM channels WHERE channel_update_1_json IS NOT NULL")
+      val res = dbs.connection.execSQLQuery("SELECT * FROM network.public_channels WHERE channel_update_1_json IS NOT NULL")
       res.next()
       res.getString("channel_announcement_json").length > 100
       res.getString("channel_update_1_json").length > 100
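The resetJsonColumns checks work because the JSON columns are derived data: the test wipes them and expects resetJsonColumns to rebuild them from the binary source of truth. A hedged sketch of such a regeneration loop (table and column names follow the test above; toJson is an assumed stand-in for eclair's serializers, and the real implementation may batch differently):

    import java.sql.Connection

    // Illustrative only: re-derive the json column from the BYTEA data column.
    def resetJsonSketch(connection: Connection, toJson: Array[Byte] => String): Unit = {
      val rows = connection.createStatement().executeQuery("SELECT node_id, data FROM network.nodes")
      val update = connection.prepareStatement("UPDATE network.nodes SET json=?::JSONB WHERE node_id=?")
      while (rows.next()) {
        update.setString(1, toJson(rows.getBytes("data")))
        update.setString(2, rows.getString("node_id"))
        update.executeUpdate()
      }
    }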
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala
index 1efa35b7de..8d10c34419 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala
@@ -76,7 +76,7 @@ class PendingCommandsDbSpec extends AnyFunSuite {
     }
   }
 
-  test("migrate database v1->v2") {
+  test("migrate database v1->v2/v3") {
     forAllDbs {
       case dbs: TestPgDatabases =>
         migrationCheck(
@@ -96,7 +96,7 @@ class PendingCommandsDbSpec extends AnyFunSuite {
            }
          },
          dbName = "pending_relay",
-          targetVersion = 2,
+          targetVersion = 3,
          postCheck = _ =>
            assert(dbs.pendingCommands.listSettlementCommands().toSet === testCases.map(tc => tc.channelId -> tc.cmd))
        )
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala
index e8f2f8706e..985ecdee2d 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala
@@ -5,6 +5,7 @@ import com.typesafe.config.{Config, ConfigFactory}
 import fr.acinq.eclair.db.pg.PgUtils.{JdbcUrlChanged, migrateTable, using}
 import fr.acinq.eclair.db.pg.PgUtils.PgLock.{LockFailure, LockFailureHandler}
 import fr.acinq.eclair.{TestKitBaseClass, TestUtils}
+import org.postgresql.jdbc.PgConnection
 import grizzled.slf4j.{Logger, Logging}
 import org.scalatest.concurrent.Eventually
 import org.scalatest.funsuite.AnyFunSuiteLike
@@ -85,6 +86,16 @@ class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually
     pg.close()
   }
 
+  test("grant rights to read-only user") {
+    val pg = EmbeddedPostgres.start()
+    pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection].execSQLUpdate("CREATE ROLE readonly NOLOGIN")
+    val config = ConfigFactory.parseString("postgres.readonly-user = readonly")
+      .withFallback(PgUtilsSpec.testConfig(pg.getPort))
+    val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}")
+    datadir.mkdirs()
+    Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow)
+  }
+
   test("migration test") {
     val pg = EmbeddedPostgres.start()
     using(pg.getPostgresDatabase.getConnection.createStatement()) { statement =>