Skip to content

Commit

Permalink
Add :search_path option to connection (#620)
Browse files Browse the repository at this point in the history
  • Loading branch information
greg-rychlewski committed Sep 15, 2022
1 parent 5a1b4ec commit 61acc59
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 2 deletions.
10 changes: 10 additions & 0 deletions lib/postgrex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ defmodule Postgrex do
| {:prepare, :named | :unnamed}
| {:transactions, :strict | :naive}
| {:types, module}
| {:search_path, [String.t()]}
| {:disconnect_on_error_codes, [atom]}
| DBConnection.start_option()

Expand Down Expand Up @@ -158,6 +159,15 @@ defmodule Postgrex do
option is only required when using custom encoding or decoding (default:
`Postgrex.DefaultTypes`);
* `:search_path` - A list of strings used to set the search path for the connection.
This is useful when, for instance, an extension like `citext` is installed in a
separate schema. If that schema is not in the connection's search path, Postgrex
might not be able to recognize the extension's data type. When this option is `nil`,
the search path is not modified. (default: `nil`).
See the [PostgreSQL docs](https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH)
for more details.
`Postgrex` uses the `DBConnection` library and supports all `DBConnection`
options like `:idle`, `:after_connect` etc. See `DBConnection.start_link/2`
for more information.
Expand Down
67 changes: 65 additions & 2 deletions lib/postgrex/protocol.ex
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ defmodule Postgrex.Protocol do
prepare: prepare,
messages: [],
ssl: ssl?,
target_server_type: target_server_type
target_server_type: target_server_type,
search_path: opts[:search_path]
}

connect_endpoints(endpoints, sock_opts ++ @sock_opts, connect_timeout, s, status, [])
Expand Down Expand Up @@ -830,7 +831,7 @@ defmodule Postgrex.Protocol do
init_recv(%{s | connection_id: pid, connection_key: key}, status, buffer)

{:ok, msg_ready(), buffer} ->
check_target_server_type(s, status, buffer)
set_search_path(s, status, buffer)

{:ok, msg_error(fields: fields), buffer} ->
disconnect(s, Postgrex.Error.exception(postgres: fields), buffer)
Expand All @@ -844,6 +845,68 @@ defmodule Postgrex.Protocol do
end
end

## set search path on connection startup

defp set_search_path(s, %{search_path: nil} = status, buffer),
do: set_search_path_done(s, status, buffer)

defp set_search_path(s, %{search_path: search_path} = status, buffer)
when is_list(search_path),
do: set_search_path_send(s, status, buffer)

defp set_search_path(_, %{search_path: search_path}, _) do
raise ArgumentError,
"expected :search_path to be a list of strings, got: #{inspect(search_path)}"
end

defp set_search_path_send(s, status, buffer) do
search_path = Enum.intersperse(status.search_path, ",")
msg = msg_query(statement: ["set search_path to " | search_path])

case msg_send(s, msg, buffer) do
:ok ->
set_search_path_recv(s, status, buffer)

{:disconnect, _, _} = dis ->
dis
end
end

defp set_search_path_recv(s, status, buffer) do
case msg_recv(s, :infinity, buffer) do
{:ok, msg_row_desc(fields: fields), buffer} ->
{[@text_type_oid], ["search_path"]} = columns(fields)
set_search_path_recv(s, status, buffer)

{:ok, msg_data_row(), buffer} ->
set_search_path_recv(s, status, buffer)

{:ok, msg_command_complete(), buffer} ->
set_search_path_recv(s, status, buffer)

{:ok, msg_ready(status: :idle), buffer} ->
set_search_path_done(s, status, buffer)

{:ok, msg_ready(status: postgres), _buffer} ->
err = %Postgrex.Error{message: "unexpected postgres status: #{postgres}"}
{:disconnect, err, s}

{:ok, msg_error(fields: fields), buffer} ->
err = Postgrex.Error.exception(postgres: fields)
{:disconnect, err, %{s | buffer: buffer}}

{:ok, msg, buffer} ->
{s, status} = handle_msg(s, status, msg)
set_search_path_recv(s, status, buffer)

{:disconnect, _, _} = dis ->
dis
end
end

defp set_search_path_done(s, status, buffer),
do: check_target_server_type(s, status, buffer)

## check_target_server_type

defp check_target_server_type(s, %{target_server_type: :any} = status, buffer),
Expand Down
27 changes: 27 additions & 0 deletions test/postgrex_test.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defmodule PostgrexTest do
use ExUnit.Case, async: false
import ExUnit.CaptureLog

test "start_link/2 raises when :ssl app is required but not started" do
on_exit(fn ->
Expand All @@ -12,4 +13,30 @@ defmodule PostgrexTest do
Postgrex.start_link(ssl: true, database: "postgrex_test")
end
end

test "start_link/2 sets search path" do
# valid argument
search_path = ["public", "extension"]
{:ok, conn} = Postgrex.start_link(database: "postgrex_test", search_path: search_path)
%{rows: [[result]]} = Postgrex.query!(conn, "show search_path", [])
assert result == Enum.join(search_path, ", ")

# invalid argument
Process.flag(:trap_exit, true)
search_path = "public, extension"

opts = [
database: "postgrex_test",
search_path: search_path,
show_sensitive_data_on_connection_error: true
]

error_log =
capture_log(fn ->
Postgrex.start_link(opts)
assert_receive {:EXIT, _, :killed}
end)

assert error_log =~ "expected :search_path to be a list of strings, got: \"#{search_path}\""
end
end
14 changes: 14 additions & 0 deletions test/query_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1557,4 +1557,18 @@ defmodule QueryTest do

assert {_, %{count: 3}, _} = Table.Reader.init(res)
end

test "search_path", context do
:ok = query("CREATE SCHEMA test_schema", [])
:ok = query("CREATE TABLE test_schema.test_table (id int, text text)", [])
:ok = query("INSERT INTO test_schema.test_table VALUES (1, 'foo')", [])

# search path does not contain the appropriate schema
%Postgrex.Error{postgres: error} = query("SELECT * from test_table", [])
assert error.message =~ "\"test_table\" does not exist"

# search path does contain the appropriate schema
{:ok, pid} = P.start_link(database: "postgrex_test", search_path: ["public", "test_schema"])
%{rows: [[1, "foo"]]} = P.query!(pid, "SELECT * from test_table", [])
end
end

0 comments on commit 61acc59

Please sign in to comment.