Skip to content

Make use of the default SSH user #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 39 additions & 49 deletions testgres/operations/remote_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
import os
import socket
import subprocess
import tempfile
import platform
Expand Down Expand Up @@ -45,46 +45,44 @@ def __init__(self, conn_params: ConnectionParams):
self.conn_params = conn_params
self.host = conn_params.host
self.ssh_key = conn_params.ssh_key
self.port = conn_params.port
self.ssh_args = []
if self.ssh_key:
self.ssh_cmd = ["-i", self.ssh_key]
else:
self.ssh_cmd = []
self.ssh_args += ["-i", self.ssh_key]
if self.port:
self.ssh_args += ["-p", self.port]
self.remote = True
self.username = conn_params.username or self.get_user()
self.username = conn_params.username
self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host
self.add_known_host(self.host)
self.tunnel_process = None
self.tunnel_port = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close_ssh_tunnel()

def establish_ssh_tunnel(self, local_port, remote_port):
"""
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
"""
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
@staticmethod
def is_port_open(host, port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1) # Таймаут для попытки соединения
Copy link
Contributor

@demonolock demonolock Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment in english, please

try:
sock.connect((host, port))
return True
except socket.error:
return False

def close_ssh_tunnel(self):
if hasattr(self, 'tunnel_process'):
if self.tunnel_process:
self.tunnel_process.terminate()
self.tunnel_process.wait()
print("SSH tunnel closed.")
del self.tunnel_process
else:
print("No active tunnel to close.")

def add_known_host(self, host):
known_hosts_path = os.path.expanduser("~/.ssh/known_hosts")
cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path)

try:
subprocess.check_call(cmd, shell=True)
logging.info("Successfully added %s to known_hosts." % host)
except subprocess.CalledProcessError as e:
raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e)))

def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
stderr=None, get_process=None, timeout=None):
Expand All @@ -95,9 +93,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
"""
ssh_cmd = []
if isinstance(cmd, str):
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, cmd]
elif isinstance(cmd, list):
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest] + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process
Expand Down Expand Up @@ -172,10 +170,6 @@ def set_env(self, var_name: str, var_val: str):
"""
return self.exec_command("export {}={}".format(var_name, var_val))

# Get environment variables
def get_user(self):
return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()

def get_name(self):
cmd = 'python3 -c "import os; print(os.name)"'
return self.exec_command(cmd, encoding=get_default_encoding()).strip()
Expand Down Expand Up @@ -246,9 +240,9 @@ def mkdtemp(self, prefix=None):
- prefix (str): The prefix of the temporary directory name.
"""
if prefix:
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"]
else:
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"]

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

Expand Down Expand Up @@ -291,8 +285,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
mode = "r+b" if binary else "r+"

with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
# Because in scp we set up port using -P option
scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]

if not truncate:
scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
subprocess.run(scp_cmd, check=False) # The file might not exist yet
tmp_file.seek(0, os.SEEK_END)

Expand All @@ -308,11 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
tmp_file.write(data)

tmp_file.flush()
scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"]
subprocess.run(scp_cmd, check=True)

remote_directory = os.path.dirname(filename)
mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"]
subprocess.run(mkdir_cmd, check=True)

os.remove(tmp_file.name)
Expand Down Expand Up @@ -377,7 +374,7 @@ def get_pid(self):
return int(self.exec_command("echo $$", encoding=get_default_encoding()))

def get_process_children(self, pid):
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"]

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

Expand All @@ -389,18 +386,11 @@ def get_process_children(self, pid):

# Database control
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
"""
Established SSH tunnel and Connects to a PostgreSQL
"""
self.establish_ssh_tunnel(local_port=port, remote_port=5432)
try:
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
)
return conn
except Exception as e:
raise Exception(f"Could not connect to the database. Error: {e}")
conn = pglib.connect(
host=host,
port=port,
database=dbname,
user=user,
password=password,
)
return conn