#!/bin/bash

# Commands:
# connect <endpoint_name>
# stop-waiting
# run-command <command> <args...>

# Report an error on stderr so it never mixes with command output.
err() {
  printf '%s\n' "$*" >&2
}

# connect <endpoint_name>
# Looks up the first SSM instance ID reported for the endpoint (via
# sagemaker_ssh_helper) and starts the SSH tunnel to it with the port
# forwards used by the other sub-commands (local 11022 -> remote 22).
# Returns non-zero if the endpoint name is missing or the lookup fails.
connect() {
  local endpoint_name=${1:-}
  if [[ -z "$endpoint_name" ]]; then
    err "Usage: ${0##*/} connect <endpoint_name>"
    return 2
  fi

  # The endpoint name is handed to Python as an argument (sys.argv[1]) rather
  # than being spliced into the program text, so quotes or other special
  # characters in the name cannot alter the Python code. The quoted 'EOF'
  # delimiter disables shell expansion inside the here-document.
  local instance_id
  instance_id=$(python - "$endpoint_name" <<'EOF'
import sys
import sagemaker; from sagemaker_ssh_helper.log import SSHLog;
import logging; logging.basicConfig(level=logging.INFO);
print(SSHLog().get_endpoint_ssm_instance_ids(sys.argv[1], retry=30)[0])
EOF
  ) || {
    err "Failed to resolve SSM instance ID for endpoint: $endpoint_name"
    return 1
  }

  if [[ -z "$instance_id" ]]; then
    err "No SSM instance ID found for endpoint: $endpoint_name"
    return 1
  fi

  sm-local-start-ssh "$instance_id" \
      -R localhost:12345:localhost:12345 \
      -L localhost:11022:localhost:22
}

# run-command <command> <args...>
# Runs the given command on the remote host through the local port 11022
# forwarded by `connect`. Arguments are passed to ssh as-is (no local word
# splitting or glob expansion); ssh itself joins them with spaces before
# handing the command line to the remote shell.
run_command() {
  ssh -i ~/.ssh/sagemaker-ssh-gw -p 11022 \
    -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
    root@localhost \
    "$@"
}

# stop-waiting
# Runs `pkill sm-wait` on the remote host.
stop_waiting() {
  run_command pkill sm-wait
}

# Dispatch on the first argument; the remaining arguments belong to the
# selected sub-command. Returns the sub-command's status, or 1 when the
# command is unknown or missing.
main() {
  local command=${1:-}
  if (( $# > 0 )); then
    shift
  fi

  case "$command" in
    connect)
      connect "$@"
      ;;
    stop-waiting)
      stop_waiting
      ;;
    run-command)
      run_command "$@"
      ;;
    *)
      err "Unknown command: $command"
      return 1
      ;;
  esac
}

main "$@"