#!/bin/bash

# Helper script that connects to a remote managed instance through SSM and starts SSH port forwarding.
# On every run it generates a new SSH key pair at ~/.ssh/sagemaker-ssh-gw, uploads the public key to S3,
# and executes a remote SSM command that copies the key from S3 into the instance's authorized keys.

# Abort on the first failing command and echo script lines as they are read.
set -e
set -v

# Positional arguments:
#   $1  - ID of the managed instance to connect to (used as the SSM target)
#   $2  - S3 location used to hand the public key over to the instance
#   $3+ - optional extra arguments passed through to ssh (e.g. -L / -R forwards)
# Validate up front: without this check a missing argument only surfaces as a
# silent exit from a failing `shift`, and an empty one as an obscure AWS CLI error.
if [[ $# -lt 2 || -z "$1" || -z "$2" ]]; then
  echo "Usage: ${0##*/} <instance-id> <s3-authorized-keys-location> [ssh port forwarding args...]" >&2
  exit 1
fi

INSTANCE_ID="$1"
SSH_AUTHORIZED_KEYS="$2"
# Drop the two consumed arguments so that only the pass-through ssh args remain.
shift 2

# TODO: make it possible to override the default (also helps avoid race conditions)
# Private key path; the matching public key lives next to it with a .pub suffix.
SSH_KEY=~/.ssh/sagemaker-ssh-gw
public_key_file="${SSH_KEY}.pub"

printf 'Generating %s and uploading public key to %s\n' "$SSH_KEY" "$SSH_AUTHORIZED_KEYS"

# Create the key pair with an empty passphrase. "yes" is supplied on stdin to
# answer ssh-keygen's overwrite prompt when a key from an earlier run exists.
ssh-keygen -f "${SSH_KEY}" -N '' <<< 'yes'

# Show the public key in the output, then publish it to S3.
cat "${public_key_file}"
aws s3 cp "${public_key_file}" "${SSH_AUTHORIZED_KEYS}"

# Resolve the region the AWS CLI would use, taken from `aws configure list`.
# Only the row whose first column is exactly "region" is considered, so other
# rows that merely contain the word (e.g. a profile name) cannot leak into the
# value. If the second column is a bare ":" separator, the value is read from
# the third column instead.
CURRENT_REGION=$(
  aws configure list \
    | awk '$1 == "region" && !seen++ { print ($2 == ":" ? $3 : $2) }'
)
# An unset region is printed as "<not set>", which would otherwise be picked up
# as the bogus value "<not". Stop here with a clear message instead of failing
# later inside the SSM calls.
if [[ -z "$CURRENT_REGION" || "$CURRENT_REGION" == "<"* ]]; then
  echo "ERROR: could not determine the AWS Region from 'aws configure list'; configure a default region first" >&2
  exit 1
fi
echo "Will use AWS Region: $CURRENT_REGION"

# Remaining positional arguments, flattened into one string; they are word-split
# again when handed to ssh.
PORT_FWD_ARGS=$*

AWS_CLI_VERSION=$(aws --version)
echo "AWS CLI version (should be v2): $AWS_CLI_VERSION"

# Shell commands executed on the instance through the AWS-RunShellScript document:
# print the remote identity, download everything under the S3 location into
# /root/.ssh/authorized_keys.d/, rebuild /root/.ssh/authorized_keys from those
# files, and print the result.
remote_commands="commands=[
        'id',
        'aws sts get-caller-identity',
        'mkdir -p /root/.ssh/authorized_keys.d/',
        'aws s3 cp --recursive ${SSH_AUTHORIZED_KEYS} /root/.ssh/authorized_keys.d/',
        'cat /root/.ssh/authorized_keys.d/* > /root/.ssh/authorized_keys',
        'cat /root/.ssh/authorized_keys'
      ]"

send_command_args=(
  --region "${CURRENT_REGION}"
  --instance-ids "${INSTANCE_ID}"
  --document-name "AWS-RunShellScript"
  --comment "Copy public key for SSH helper"
  --parameters "${remote_commands}"
  --no-paginate
  --output text
)

aws --no-cli-pager ssm send-command "${send_command_args[@]}"

# Give the remote commands time to finish before connecting; nothing here
# verifies that they actually completed.
sleep 5

# TODO: fetch SSH key fingerprint from the logs and enable StrictHostKeyChecking, for better security

printf 'Connecting to %s as proxy and starting port forwarding with the args: %s\n' "$INSTANCE_ID" "$PORT_FWD_ARGS"

# The AWS-StartPortForwardingSession feature of SSM is not used here, because port forwarding is needed
#  in both directions via the -L and -R parameters of SSH. This is useful for forwarding the PyCharm
#  license server, which needs the -R option. SSM only forwards ports from the server (equivalent to -L).
# Instead, SSM carries the SSH connection itself through the AWS-StartSSHSession document.
proxy_command="aws ssm start-session --region '${CURRENT_REGION}' --target '${INSTANCE_ID}' --document-name AWS-StartSSHSession --parameters portNumber=%p"

# PORT_FWD_ARGS is left unquoted on purpose so that it splits into separate ssh arguments.
# shellcheck disable=SC2206
ssh_args=(
  -o User=root
  -o "IdentityFile=${SSH_KEY}"
  -o IdentitiesOnly=yes
  -o "ProxyCommand=${proxy_command}"
  -o ServerAliveInterval=15
  -o ServerAliveCountMax=3
  -o StrictHostKeyChecking=no
  -N
  $PORT_FWD_ARGS
  "$INSTANCE_ID"
)

ssh "${ssh_args[@]}"
