Files
2025-12-29 07:34:20 +01:00

445 lines
17 KiB
Python

import gzip
import os
import posixpath
import stat
import struct
import subprocess
import threading
import time
import uuid
from datetime import datetime

from flask import Flask, render_template, request, jsonify, session
import paramiko
# =============================================================================
# Configuration
# =============================================================================
class Config:
    """Deployment configuration, resolved from environment variables at import time."""

    # SFTP/SCP server that hosts the backup dumps.
    SCP_HOST = os.getenv('SCP_HOST', 'localhost')
    SCP_PORT = int(os.getenv('SCP_PORT', 22))
    SCP_USERNAME = os.getenv('SCP_USERNAME', 'root')
    # Remote directory shown when a session first connects.
    SCP_DEFAULT_PATH = os.getenv('SCP_DEFAULT_PATH', '/backups')
    # MySQL server that receives restored dumps.
    DB_HOST = os.getenv('DB_HOST', 'db')
    DB_PORT = int(os.getenv('DB_PORT', 3306))
    DB_USER = os.getenv('DB_USER', 'root')
    DB_PASSWORD = os.getenv('DB_PASSWORD', '')
    # Comma-separated whitelist of databases that may be restored into.
    DB_AVAILABLE = os.getenv('DB_AVAILABLE', 'thetool,addressdb').split(',')
    # Local scratch directory for downloaded dump files.
    DOWNLOAD_PATH = '/app/downloads'
    # Directory containing private SSH keys offered for key-based auth.
    SSH_KEYS_PATH = '/app/ssh-keys'
    # NOTE(review): the random fallback changes on every process restart,
    # invalidating all existing Flask sessions — set SECRET_KEY explicitly
    # in production.
    SECRET_KEY = os.getenv('SECRET_KEY', os.urandom(24).hex())
# =============================================================================
# SFTP Client
# =============================================================================
class SFTPClient:
    """Thin wrapper around paramiko for browsing and downloading remote files.

    Call one of the ``connect_*`` methods before using the listing/download
    methods, and ``close()`` when done.
    """

    def __init__(self, host, port, username):
        self.host = host
        self.port = port
        self.username = username
        self.client = None  # paramiko.SSHClient, set by connect_*()
        self.sftp = None    # paramiko.SFTPClient, set by connect_*()

    def connect_password(self, password):
        """Open an SFTP session authenticated with a password."""
        self.client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy blindly trusts unknown host keys,
        # which allows man-in-the-middle attacks; consider pinning host keys.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.client.connect(
            hostname=self.host, port=self.port, username=self.username,
            password=password, look_for_keys=False, allow_agent=False
        )
        self.sftp = self.client.open_sftp()

    def connect_key(self, key_path, passphrase=None):
        """Open an SFTP session authenticated with a private key file."""
        self.client = paramiko.SSHClient()
        # NOTE(security): see connect_password() — AutoAddPolicy trusts any host.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.client.connect(
            hostname=self.host, port=self.port, username=self.username,
            key_filename=key_path, passphrase=passphrase,
            look_for_keys=False, allow_agent=False
        )
        self.sftp = self.client.open_sftp()

    def list_directory(self, path):
        """List a remote directory.

        Returns a list of entry dicts sorted directories-first, then
        newest-first by modification time.
        """
        entries = []
        for entry in self.sftp.listdir_attr(path):
            is_dir = stat.S_ISDIR(entry.st_mode)
            entries.append({
                'name': entry.filename,
                'size': entry.st_size,
                'size_human': self._human_size(entry.st_size),
                'mtime': entry.st_mtime,
                'mtime_human': datetime.fromtimestamp(entry.st_mtime).strftime('%Y-%m-%d %H:%M'),
                'is_dir': is_dir,
                'is_sql': entry.filename.endswith(('.sql', '.sql.gz')),
                # Fix: remote SFTP paths are always POSIX — os.path.join would
                # produce backslash paths when this app runs on Windows.
                'path': posixpath.join(path, entry.filename)
            })
        return sorted(entries, key=lambda x: (not x['is_dir'], -x['mtime']))

    def get_file_info(self, path):
        """Return a metadata dict (size, mtime, name) for a single remote file."""
        entry = self.sftp.stat(path)
        return {
            'name': os.path.basename(path),
            'size': entry.st_size,
            'size_human': self._human_size(entry.st_size),
            'mtime': entry.st_mtime,
            'mtime_human': datetime.fromtimestamp(entry.st_mtime).strftime('%Y-%m-%d %H:%M'),
            'path': path
        }

    def download_file(self, remote_path, local_path, callback=None):
        """Download remote_path to local_path; callback(transferred, total) if given."""
        self.sftp.get(remote_path, local_path, callback=callback)

    def close(self):
        """Close the SFTP channel and the underlying SSH transport."""
        if self.sftp:
            self.sftp.close()
        if self.client:
            self.client.close()

    @staticmethod
    def _human_size(size):
        """Format a byte count as a human-readable string (e.g. '2.0 KB')."""
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size < 1024:
                return f"{size:.1f} {unit}"
            size /= 1024
        return f"{size:.1f} PB"
# =============================================================================
# Database Restore
# =============================================================================
class DatabaseRestore:
    """Restore SQL dump files (plain or gzip-compressed) into a whitelisted MySQL database."""

    def __init__(self):
        self.host = Config.DB_HOST
        self.port = Config.DB_PORT
        self.user = Config.DB_USER
        self.password = Config.DB_PASSWORD
        self.available_dbs = Config.DB_AVAILABLE  # whitelist of restorable databases
        self.cancelled = False  # cooperative cancel flag, polled while streaming

    def cancel(self):
        """Request cancellation of an in-progress restore (checked per chunk)."""
        self.cancelled = True

    @staticmethod
    def get_gzip_uncompressed_size(filepath):
        """Read the gzip ISIZE trailer: uncompressed size modulo 2**32.

        NOTE: by the gzip spec (RFC 1952) this is wrong for dumps larger than
        4 GiB; it is only used for progress estimation here.
        """
        with open(filepath, 'rb') as f:
            f.seek(-4, 2)
            return struct.unpack('<I', f.read(4))[0]

    def _mysql_cmd(self, *extra_args):
        """Build the base mysql CLI invocation plus any extra arguments.

        NOTE(security): the password is passed on the command line and is
        visible in the process list; MYSQL_PWD or a defaults file would be safer.
        """
        return ['mysql', '-h', self.host, '-P', str(self.port), '-u', self.user,
                f'-p{self.password}'] + list(extra_args)

    def ensure_database_exists(self, target_db):
        """Create target_db (utf8mb4) if missing.

        Raises ValueError if target_db is not whitelisted, Exception on failure.
        """
        if target_db not in self.available_dbs:
            raise ValueError(f"Invalid database: {target_db}")
        cmd = self._mysql_cmd('-e', f"CREATE DATABASE IF NOT EXISTS `{target_db}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception(f"Failed to create database: {result.stderr}")

    def clear_database(self, target_db):
        """Drop every table in target_db; returns the number of tables dropped.

        Raises ValueError for non-whitelisted databases, Exception if the
        table listing fails.
        """
        # Fix: validate here too — target_db is interpolated into SQL, and this
        # method must be safe even when called directly (not via restore_from_file).
        if target_db not in self.available_dbs:
            raise ValueError(f"Invalid database: {target_db}")
        cmd = self._mysql_cmd('-N', '-e', f"SELECT table_name FROM information_schema.tables WHERE table_schema='{target_db}'")
        result = subprocess.run(cmd, capture_output=True, text=True)
        # Fix: a failed listing query was previously treated silently as "no tables".
        if result.returncode != 0:
            raise Exception(f"Failed to list tables: {result.stderr}")
        tables = [t.strip() for t in result.stdout.strip().split('\n') if t.strip()]
        if tables:
            # Disable FK checks so drop order doesn't matter.
            drop_sql = "SET FOREIGN_KEY_CHECKS=0; " + "; ".join(f"DROP TABLE IF EXISTS `{t}`" for t in tables) + "; SET FOREIGN_KEY_CHECKS=1;"
            subprocess.run(self._mysql_cmd(target_db, '-e', drop_sql), check=True, capture_output=True)
        return len(tables)

    def restore_from_file(self, file_path, target_db, progress_callback=None):
        """Wipe target_db and import the dump at file_path into it.

        ``.gz`` files are streamed decompressed into mysql's stdin in 1 MiB
        chunks (progress_callback receives cumulative uncompressed bytes);
        plain files are fed directly. Returns a summary dict.
        Raises ValueError / Exception on validation, cancellation, or mysql errors.
        """
        if target_db not in self.available_dbs:
            raise ValueError(f"Invalid database: {target_db}")
        self.cancelled = False
        self.ensure_database_exists(target_db)
        tables_dropped = self.clear_database(target_db)
        if self.cancelled:
            raise Exception("Restore cancelled by user")
        mysql_cmd = self._mysql_cmd(target_db)
        process = None
        try:
            if file_path.endswith('.gz'):
                with gzip.open(file_path, 'rb') as f:
                    process = subprocess.Popen(mysql_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                    bytes_read = 0
                    while True:
                        if self.cancelled:
                            process.terminate()
                            raise Exception("Restore cancelled by user")
                        chunk = f.read(1024 * 1024)
                        if not chunk:
                            break
                        # Detect mysql dying early so we report its stderr
                        # rather than a bare broken pipe.
                        if process.poll() is not None:
                            raise Exception(f"MySQL terminated: {process.stderr.read().decode()}")
                        try:
                            process.stdin.write(chunk)
                            process.stdin.flush()
                        except BrokenPipeError:
                            raise Exception(f"MySQL connection lost: {process.stderr.read().decode()}")
                        bytes_read += len(chunk)
                        if progress_callback:
                            progress_callback(bytes_read)
                    process.stdin.close()
                    process.wait(timeout=300)
                    if process.returncode != 0:
                        raise Exception(f"MySQL restore failed: {process.stderr.read().decode()}")
            else:
                with open(file_path, 'rb') as f:
                    result = subprocess.run(mysql_cmd, stdin=f, capture_output=True, timeout=600)
                if result.returncode != 0:
                    raise Exception(f"MySQL restore failed: {result.stderr.decode()}")
        except subprocess.TimeoutExpired:
            if process:
                process.kill()
            raise Exception("MySQL restore timed out")
        return {'tables_dropped': tables_dropped, 'file': os.path.basename(file_path), 'database': target_db}
# =============================================================================
# Flask Application
# =============================================================================
app = Flask(__name__)
app.config['SECRET_KEY'] = Config.SECRET_KEY

# In-memory job tracking. NOTE(review): not shared across worker processes
# and lost on restart — fine for a single-process deployment only.
jobs = {}       # job_id -> progress/status dict
restorers = {}  # job_id -> active DatabaseRestore (so /api/cancel can reach it)
@app.route('/')
def index():
    """Render the main UI with the database whitelist and SFTP target info."""
    context = {
        'databases': Config.DB_AVAILABLE,
        'scp_host': Config.SCP_HOST,
        'scp_username': Config.SCP_USERNAME,
    }
    return render_template('index.html', **context)
@app.route('/health')
def health():
    """Liveness probe: always reports OK while the process is up."""
    payload = {'status': 'ok'}
    return payload
@app.route('/api/keys', methods=['GET'])
def list_keys():
    """List private SSH key files available in the configured keys directory.

    Public keys (``*.pub``) and dotfiles are excluded.
    """
    keys = []
    if os.path.exists(Config.SSH_KEYS_PATH):
        for name in os.listdir(Config.SSH_KEYS_PATH):
            if name.endswith('.pub') or name.startswith('.'):
                continue
            keys.append(name)
    return jsonify({'success': True, 'keys': keys})
@app.route('/api/connect', methods=['POST'])
def connect():
    """Authenticate against the SFTP server and cache credentials in the session.

    Expects JSON: ``auth_type`` ('password' or key-based), plus ``password``
    or ``key_file``/``key_passphrase``. Returns the default directory listing
    on success, a JSON error with HTTP 400 otherwise.
    """
    # Fix: request.json raises (non-JSON HTML error page) on a missing or
    # malformed body; fall back to an empty dict so validation answers in JSON.
    data = request.get_json(silent=True) or {}
    auth_type = data.get('auth_type', 'password')
    try:
        client = SFTPClient(Config.SCP_HOST, Config.SCP_PORT, Config.SCP_USERNAME)
        if auth_type == 'password':
            if not data.get('password'):
                return jsonify({'success': False, 'error': 'Password is required'}), 400
            client.connect_password(data['password'])
        else:
            if not data.get('key_file'):
                return jsonify({'success': False, 'error': 'SSH key file is required'}), 400
            # Fix(security): basename() prevents path traversal (e.g.
            # "../../etc/ssh/ssh_host_rsa_key") out of the keys directory.
            key_name = os.path.basename(data['key_file'])
            key_path = os.path.join(Config.SSH_KEYS_PATH, key_name)
            if not os.path.exists(key_path):
                return jsonify({'success': False, 'error': 'SSH key file not found'}), 400
            client.connect_key(key_path, data.get('key_passphrase'))
        files = client.list_directory(Config.SCP_DEFAULT_PATH)
        client.close()
        # NOTE(security): credentials live in the Flask session cookie, which
        # is signed but client-readable; use server-side sessions in production.
        session['sftp_auth'] = {
            'type': auth_type,
            'password': data.get('password'),
            'key_file': os.path.basename(data['key_file']) if data.get('key_file') else None,
            'key_passphrase': data.get('key_passphrase')
        }
        session['connected'] = True
        return jsonify({'success': True, 'files': files, 'path': Config.SCP_DEFAULT_PATH, 'host': Config.SCP_HOST, 'username': Config.SCP_USERNAME})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 400
@app.route('/api/browse', methods=['POST'])
def browse():
    """List a remote directory using the credentials cached in the session.

    Expects JSON with an optional ``path`` (defaults to SCP_DEFAULT_PATH).
    """
    if not session.get('connected'):
        return jsonify({'success': False, 'error': 'Not connected'}), 401
    auth = session.get('sftp_auth')
    if not auth:
        return jsonify({'success': False, 'error': 'Not authenticated'}), 401
    # Fix: request.json raises on a missing/malformed body; tolerate it.
    data = request.get_json(silent=True) or {}
    path = data.get('path', Config.SCP_DEFAULT_PATH)
    try:
        client = SFTPClient(Config.SCP_HOST, Config.SCP_PORT, Config.SCP_USERNAME)
        if auth['type'] == 'password':
            client.connect_password(auth['password'])
        else:
            # Fix(security): basename() keeps key lookups inside the keys dir.
            key_path = os.path.join(Config.SSH_KEYS_PATH, os.path.basename(auth['key_file']))
            client.connect_key(key_path, auth.get('key_passphrase'))
        files = client.list_directory(path)
        client.close()
        return jsonify({'success': True, 'files': files, 'path': path})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 400
@app.route('/api/disconnect', methods=['POST'])
def disconnect():
    """Drop the cached SFTP credentials and connection flag from the session."""
    for key in ('sftp_auth', 'connected'):
        session.pop(key, None)
    return jsonify({'success': True})
@app.route('/api/databases', methods=['GET'])
def list_databases():
    """Return the whitelist of databases restores are allowed to target."""
    payload = {'success': True, 'databases': Config.DB_AVAILABLE}
    return jsonify(payload)
@app.route('/api/restore', methods=['POST'])
def restore():
    """Start a background restore job: download a remote dump and import it.

    Expects JSON with ``file`` (remote path) and ``database`` (whitelisted
    name). Returns a ``job_id`` to poll via /api/status/<job_id>.
    """
    if not session.get('connected'):
        return jsonify({'success': False, 'error': 'Not connected to SFTP'}), 401
    # Fix: request.json raises on a missing/malformed body; tolerate it.
    data = request.get_json(silent=True) or {}
    remote_file = data.get('file')
    target_db = data.get('database')
    if not remote_file:
        return jsonify({'success': False, 'error': 'No file selected'}), 400
    if target_db not in Config.DB_AVAILABLE:
        # Fix: was a pointless f-string with no placeholders.
        return jsonify({'success': False, 'error': 'Invalid database'}), 400
    auth = session.get('sftp_auth')
    if not auth:
        return jsonify({'success': False, 'error': 'Not authenticated'}), 401
    job_id = str(uuid.uuid4())
    jobs[job_id] = {
        'status': 'starting', 'progress': 0, 'file': os.path.basename(remote_file),
        'database': target_db, 'started_at': time.time(), 'message': 'Initializing...'
    }
    # Daemon thread: the job must not block interpreter shutdown.
    thread = threading.Thread(target=run_restore, args=(job_id, remote_file, target_db, dict(auth)))
    thread.daemon = True
    thread.start()
    return jsonify({'success': True, 'job_id': job_id})
def run_restore(job_id, remote_file, target_db, auth):
    """Background worker: download the remote dump, then restore it into target_db.

    Progress is reported through the module-level ``jobs`` dict (0-45%
    download, 50-95% restore). Cooperative cancellation is signalled by
    jobs[job_id]['cancelled']; all errors are recorded on the job record
    rather than raised (this runs in a daemon thread).
    """
    local_file = os.path.join(Config.DOWNLOAD_PATH, os.path.basename(remote_file))
    client = None
    try:
        if jobs[job_id].get('cancelled'):
            jobs[job_id].update({'status': 'cancelled', 'message': 'Restore cancelled by user'})
            return
        jobs[job_id].update({'status': 'downloading', 'message': 'Connecting to remote server...'})
        client = SFTPClient(Config.SCP_HOST, Config.SCP_PORT, Config.SCP_USERNAME)
        if auth['type'] == 'password':
            client.connect_password(auth['password'])
        else:
            client.connect_key(os.path.join(Config.SSH_KEYS_PATH, auth['key_file']), auth.get('key_passphrase'))
        file_info = client.get_file_info(remote_file)
        jobs[job_id]['file_size'] = file_info['size_human']
        jobs[job_id]['message'] = f'Downloading {file_info["size_human"]}...'

        def download_progress(transferred, total):
            # Raising from the callback aborts paramiko's transfer loop.
            if jobs[job_id].get('cancelled'):
                raise Exception("Download cancelled by user")
            jobs[job_id].update({'progress': int((transferred / total) * 45) if total > 0 else 0, 'downloaded': transferred, 'total': total})

        client.download_file(remote_file, local_file, callback=download_progress)
        client.close()
        client = None  # closed cleanly; skip the finally-close
        if jobs[job_id].get('cancelled'):
            jobs[job_id].update({'status': 'cancelled', 'message': 'Restore cancelled by user'})
            if os.path.exists(local_file):
                os.remove(local_file)
            return
        jobs[job_id].update({'progress': 45, 'message': 'Download complete. Preparing restore...', 'status': 'restoring'})
        jobs[job_id]['progress'] = 50
        jobs[job_id]['message'] = f'Clearing database {target_db}...'
        restorer = DatabaseRestore()
        restorers[job_id] = restorer
        # Progress denominator: gzip trailer for .gz dumps, file size otherwise.
        uncompressed_size = (restorer.get_gzip_uncompressed_size(local_file)
                             if local_file.endswith('.gz') else os.path.getsize(local_file))

        def restore_progress(bytes_processed):
            if jobs[job_id].get('cancelled'):
                restorer.cancel()
            pct = 50 + min(45, int((bytes_processed / uncompressed_size) * 45)) if uncompressed_size > 0 else 50
            jobs[job_id].update({'progress': pct, 'message': f'Restoring to {target_db}... ({bytes_processed // (1024*1024)} MB / {uncompressed_size // (1024*1024)} MB)'})

        result = restorer.restore_from_file(local_file, target_db, progress_callback=restore_progress)
        if os.path.exists(local_file):
            os.remove(local_file)
        jobs[job_id].update({
            'status': 'completed', 'progress': 100,
            'message': f'Restore complete! Dropped {result["tables_dropped"]} tables and imported {result["file"]}',
            'completed_at': time.time(), 'duration': time.time() - jobs[job_id]['started_at']
        })
    except Exception as e:
        error_msg = str(e)
        # Cancellation surfaces as an exception from the progress callbacks.
        if 'cancelled' in error_msg.lower():
            jobs[job_id].update({'status': 'cancelled', 'message': 'Restore cancelled by user'})
        else:
            jobs[job_id].update({'status': 'error', 'error': error_msg, 'message': f'Error: {error_msg}'})
        if os.path.exists(local_file):
            os.remove(local_file)
    finally:
        # Fix: the SFTP connection previously leaked when the download raised
        # (close only ran on the success path).
        if client is not None:
            try:
                client.close()
            except Exception:
                pass
        restorers.pop(job_id, None)
@app.route('/api/status/<job_id>')
def status(job_id):
    """Return a snapshot of a job's progress plus a human-readable elapsed time."""
    record = jobs.get(job_id)
    if record is None:
        return jsonify({'success': False, 'error': 'Job not found'}), 404
    snapshot = dict(record)
    snapshot['success'] = True
    if 'started_at' in snapshot:
        end_time = snapshot.get('completed_at') or time.time()
        minutes, seconds = divmod(int(end_time - snapshot['started_at']), 60)
        snapshot['elapsed'] = f'{minutes}m {seconds}s'
    return jsonify(snapshot)
@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """Return every known job record keyed by job id."""
    snapshot = dict(jobs)
    return jsonify({'success': True, 'jobs': snapshot})
@app.route('/api/cancel/<job_id>', methods=['POST'])
def cancel(job_id):
    """Request cooperative cancellation of a running job."""
    job = jobs.get(job_id)
    if job is None:
        return jsonify({'success': False, 'error': 'Job not found'}), 404
    if job['status'] in ('completed', 'error', 'cancelled'):
        return jsonify({'success': False, 'error': 'Job already finished'}), 400
    job['cancelled'] = True
    restorer = restorers.get(job_id)
    if restorer is not None:
        restorer.cancel()
    return jsonify({'success': True, 'message': 'Cancel signal sent'})
if __name__ == '__main__':
    # NOTE(security): debug=True enables the Werkzeug interactive debugger
    # (remote code execution if exposed) and 0.0.0.0 binds all interfaces —
    # development use only; run behind a WSGI server in production.
    app.run(host='0.0.0.0', port=8082, debug=True)