Skip to content

Make use of the default SSH user #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@

from .defaults import \
default_dbname, \
default_username, \
generate_app_name

from .exceptions import \
Expand Down Expand Up @@ -683,8 +682,6 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
If False, waits for the instance to be in primary mode. Default is False.
max_attempts:
"""
if not username:
username = default_username()
self.start()

if replica:
Expand All @@ -694,7 +691,7 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
# Call poll_query_until until the expected value is returned
self.poll_query_until(query=query,
dbname=dbname,
username=username,
username=username or self.os_ops.username,
suppress={InternalError,
QueryException,
ProgrammingError,
Expand Down Expand Up @@ -967,15 +964,13 @@ def psql(self,
>>> psql(query='select 3', ON_ERROR_STOP=1)
"""

# Set default arguments
dbname = dbname or default_dbname()
username = username or default_username()

psql_params = [
self._get_bin_path("psql"),
"-p", str(self.port),
"-h", self.host,
"-U", username,
"-U", username or self.os_ops.username,
"-X", # no .psqlrc
"-A", # unaligned output
"-t", # print rows only
Expand Down Expand Up @@ -1087,18 +1082,15 @@ def tmpfile():
fname = self.os_ops.mkstemp(prefix=TMP_DUMP)
return fname

# Set default arguments
dbname = dbname or default_dbname()
username = username or default_username()
filename = filename or tmpfile()

_params = [
self._get_bin_path("pg_dump"),
"-p", str(self.port),
"-h", self.host,
"-f", filename,
"-U", username,
"-d", dbname,
"-U", username or self.os_ops.username,
"-d", dbname or default_dbname(),
"-F", format.value
] # yapf: disable

Expand All @@ -1118,7 +1110,7 @@ def restore(self, filename, dbname=None, username=None):

# Set default arguments
dbname = dbname or default_dbname()
username = username or default_username()
username = username or self.os_ops.username

_params = [
self._get_bin_path("pg_restore"),
Expand Down Expand Up @@ -1388,15 +1380,13 @@ def pgbench(self,
if options is None:
options = []

# Set default arguments
dbname = dbname or default_dbname()
username = username or default_username()

_params = [
self._get_bin_path("pgbench"),
"-p", str(self.port),
"-h", self.host,
"-U", username,
"-U", username or self.os_ops.username
] + options # yapf: disable

# should be the last one
Expand Down Expand Up @@ -1463,15 +1453,13 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs):
>>> pgbench_run(time=10)
"""

# Set default arguments
dbname = dbname or default_dbname()
username = username or default_username()

_params = [
self._get_bin_path("pgbench"),
"-p", str(self.port),
"-h", self.host,
"-U", username,
"-U", username or self.os_ops.username
] + options # yapf: disable

for key, value in iteritems(kwargs):
Expand Down
6 changes: 1 addition & 5 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, conn_params=None):
self.host = conn_params.host
self.ssh_key = None
self.remote = False
self.username = conn_params.username or self.get_user()
self.username = conn_params.username or getpass.getuser()

@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
Expand Down Expand Up @@ -130,10 +130,6 @@ def set_env(self, var_name, var_val):
# Check if the directory is already in PATH
os.environ[var_name] = var_val

# Get environment variables
def get_user(self):
return self.username or getpass.getuser()

def get_name(self):
return os.name

Expand Down
3 changes: 1 addition & 2 deletions testgres/operations/os_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def set_env(self, var_name, var_val):
# Check if the directory is already in PATH
raise NotImplementedError()

# Get environment variables
def get_user(self):
raise NotImplementedError()
return self.username

def get_name(self):
raise NotImplementedError()
Expand Down
28 changes: 13 additions & 15 deletions testgres/operations/remote_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
import getpass
import os
import logging
import platform
import subprocess
import tempfile
import platform

# we support both pg8000 and psycopg2
try:
Expand Down Expand Up @@ -52,7 +53,8 @@ def __init__(self, conn_params: ConnectionParams):
if self.port:
self.ssh_args += ["-p", self.port]
self.remote = True
self.username = conn_params.username or self.get_user()
self.username = conn_params.username or getpass.getuser()
self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
self.add_known_host(self.host)
self.tunnel_process = None

Expand Down Expand Up @@ -97,9 +99,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
"""
ssh_cmd = []
if isinstance(cmd, str):
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + [cmd]
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + [cmd]
elif isinstance(cmd, list):
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + cmd
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process
Expand Down Expand Up @@ -174,10 +176,6 @@ def set_env(self, var_name: str, var_val: str):
"""
return self.exec_command("export {}={}".format(var_name, var_val))

# Get environment variables
def get_user(self):
return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()

def get_name(self):
cmd = 'python3 -c "import os; print(os.name)"'
return self.exec_command(cmd, encoding=get_default_encoding()).strip()
Expand Down Expand Up @@ -248,9 +246,9 @@ def mkdtemp(self, prefix=None):
- prefix (str): The prefix of the temporary directory name.
"""
if prefix:
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"]
else:
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", "mktemp -d"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"]

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

Expand Down Expand Up @@ -296,7 +294,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
# For scp the port is specified by a "-P" option
scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
if not truncate:
scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
subprocess.run(scp_cmd, check=False) # The file might not exist yet
tmp_file.seek(0, os.SEEK_END)

Expand All @@ -312,11 +310,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
tmp_file.write(data)

tmp_file.flush()
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"]
subprocess.run(scp_cmd, check=True)

remote_directory = os.path.dirname(filename)
mkdir_cmd = ['ssh'] + self.ssh_args + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"]
subprocess.run(mkdir_cmd, check=True)

os.remove(tmp_file.name)
Expand Down Expand Up @@ -381,7 +379,7 @@ def get_pid(self):
return int(self.exec_command("echo $$", encoding=get_default_encoding()))

def get_process_children(self, pid):
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"]

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

Expand Down