Coverage for /home/antoine/projects/xpra-git/dist/python3/lib64/python/xpra/server/auth/sqlauthbase.py : 40%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# This file is part of Xpra.
3# Copyright (C) 2017-2020 Antoine Martin <antoine@xpra.org>
4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
5# later version. See the file COPYING for details.
7from xpra.util import csv, parse_simple_dict
8from xpra.os_util import getuid, getgid
9from xpra.server.auth.sys_auth_base import SysAuthenticator, log
class SQLAuthenticator(SysAuthenticator):
    """
    Authenticator base class which looks up passwords and session data
    using SQL queries against a `users` table.

    The database access itself is left to subclasses, via `db_cursor`.
    """

    def __init__(self, username, **kwargs):
        """
        Both SQL queries can be overridden with the `password_query`
        and `sessions_query` keyword arguments,
        all the other keyword arguments are passed to the parent class.
        """
        default_sessions_query = (
            "SELECT uid, gid, displays, env_options, session_options "
            "FROM users WHERE username=(%s) AND password=(%s)"
        )
        self.password_query = kwargs.pop("password_query", "SELECT password FROM users WHERE username=(%s)")
        self.sessions_query = kwargs.pop("sessions_query", default_sessions_query)
        super().__init__(username, **kwargs)
        #`authenticate_hmac` is not defined in this class,
        #presumably it is inherited from SysAuthenticator
        self.authenticate_check = self.authenticate_hmac

    def db_cursor(self, *sqlargs):
        """
        Must be implemented by subclasses.
        The callers in this class pass the query and its parameters,
        and call `fetchall()` or `fetchone()` on the object returned.
        """
        raise NotImplementedError()

    def get_passwords(self):
        """
        Returns a tuple with the first column of every row matched
        by the password query, converted to strings,
        or None if the query did not match any rows.
        """
        rows = self.db_cursor(self.password_query, (self.username,)).fetchall()
        if rows:
            return tuple(str(row[0]) for row in rows)
        log.info("username '%s' not found in sqlauth database", self.username)
        return None

    def get_sessions(self):
        """
        Runs the sessions query with the username and the password used
        (or an empty string if no password is set) and parses the first row.
        Returns None if there is no matching row.
        """
        query_args = (self.username, self.password_used or "")
        row = self.db_cursor(self.sessions_query, query_args).fetchone()
        return self.parse_session_data(row) if row else None

    def parse_session_data(self, data):
        """
        Converts a database row into the tuple:
        `(uid, gid, displays, env_options, session_options)`.

        Only `uid` and `gid` are required, the other fields are optional:
        `displays` is split on commas, the two options fields are parsed
        with `parse_simple_dict` using ";" as separator.
        Returns None (after logging the error) if the row cannot be parsed.
        """
        try:
            uid, gid = data[0], data[1]
            field_count = len(data)
            displays = [d.strip() for d in str(data[2]).split(",")] if field_count>2 else []
            env_options = parse_simple_dict(str(data[3]), ";") if field_count>3 else {}
            session_options = parse_simple_dict(str(data[4]), ";") if field_count>4 else {}
        except Exception as e:
            log("parse_session_data() error on row %s", data, exc_info=True)
            log.error("Error: sqlauth database row parsing problem:")
            log.error(" %s", e)
            return None
        return uid, gid, displays, env_options, session_options
class DatabaseUtilBase:
    """
    Base class for the helpers which manage the sqlauth `users` table:
    create the table, add / remove / list users and test authentication.

    Subclasses must implement `exec_database_sql_script`
    and `get_authenticator_class`.
    """

    def __init__(self, uri):
        self.uri = uri
        #placeholder inserted in the SQL statements for each bound parameter
        #(presumably subclasses change this to match their database driver)
        self.param = "?"

    def exec_database_sql_script(self, cursor_cb, *sqlargs):
        """
        Must be implemented by subclasses.
        The callers in this class pass the SQL statement followed by an optional
        tuple of parameters, `cursor_cb` is either None or a callable which
        takes the cursor as its only argument.
        """
        raise NotImplementedError()

    def create(self):
        """ Creates the `users` table. """
        sql = ("CREATE TABLE users ("
               "username VARCHAR(255) NOT NULL, "
               "password VARCHAR(255), "
               "uid VARCHAR(63), "
               "gid VARCHAR(63), "
               "displays VARCHAR(8191), "
               "env_options VARCHAR(8191), "
               "session_options VARCHAR(8191))")
        self.exec_database_sql_script(None, sql)

    def add_user(self, username, password, uid=None, gid=None,
                 displays="", env_options="", session_options=""):
        """
        Inserts a new row in the `users` table.

        `uid` and `gid` default to the values returned by `getuid()` and
        `getgid()`, evaluated when this method is called rather than
        once when the class is defined.
        """
        if uid is None:
            uid = getuid()
        if gid is None:
            gid = getgid()
        values = (username, password, uid, gid, displays, env_options, session_options)
        #one placeholder per value:
        placeholders = ", ".join((self.param,)*len(values))
        sql = "INSERT INTO users(username, password, uid, gid, displays, env_options, session_options) "+\
              "VALUES(%s)" % placeholders
        self.exec_database_sql_script(None, sql, values)

    def remove_user(self, username, password=None):
        """
        Deletes the rows matching `username`,
        and also matching `password` if one is supplied.
        """
        sql = "DELETE FROM users WHERE username=%s" % self.param
        sqlargs = (username, )
        if password:
            sql += " AND password=%s" % self.param
            sqlargs = (username, password)
        self.exec_database_sql_script(None, sql, sqlargs)

    def list_users(self):
        """ Prints all the rows of the `users` table as a text table. """
        fields = ("username", "password", "uid", "gid", "displays", "env_options", "session_options")
        def fmt(values, sizes):
            #right-align each value in its column, with "|" around every cell
            cells = [str(value).rjust(sizes[i]) for i, value in enumerate(values)]
            if not cells:
                return ""
            return "|"+"|".join(cells)+"|"
        def cursor_callback(cursor):
            #always close the cursor, even if fetching or printing fails:
            try:
                rows = cursor.fetchall()
                if not rows:
                    print("no rows found")
                    return
                print("%i rows found:" % len(rows))
                #calculate max size for each field:
                sizes = [len(x)+1 for x in fields]
                for row in rows:
                    for i, value in enumerate(row):
                        sizes[i] = max(sizes[i], len(str(value))+1)
                #account for the separators between and around the columns:
                total = sum(sizes)+len(fields)+1
                print("-"*total)
                print(fmt((field.replace("_", " ") for field in fields), sizes))
                print("-"*total)
                for row in rows:
                    print(fmt(row, sizes))
            finally:
                cursor.close()
        sql = "SELECT %s FROM users" % csv(fields)
        self.exec_database_sql_script(cursor_callback, sql)

    def authenticate(self, username, password):
        """
        Verifies that `password` is one of the passwords found for `username`
        and that session data can be retrieved with it.

        Raises AssertionError if any of the checks fails.
        The checks use explicit `raise` statements rather than `assert`
        so that they are still enforced when python runs with `-O`.
        """
        auth_class = self.get_authenticator_class()
        a = auth_class(username, self.uri)
        passwords = a.get_passwords()
        if not passwords:
            raise AssertionError("no passwords found for username '%s'" % username)
        log("authenticate: got %i passwords", len(passwords))
        if password not in passwords:
            raise AssertionError("invalid password for username '%s'" % username)
        a.password_used = password
        sessions = a.get_sessions()
        if not sessions:
            raise AssertionError("no sessions found")
        log("sql authentication success, found sessions: %s", sessions)

    def get_authenticator_class(self):
        """
        Must be implemented by subclasses.
        `authenticate` instantiates the class returned with `(username, uri)`.
        """
        raise NotImplementedError()
def run_dbutil(DatabaseUtilClass=DatabaseUtilBase, conn_str="databaseURI", argv=()):
    """
    Command line entry point for managing a sqlauth database.

    `argv` is expected to hold the program name, the database URI
    and the command (`create`, `list`, `add`, `remove` or `authenticate`),
    followed by the arguments of the command.
    `DatabaseUtilClass` is instantiated with the URI and
    `conn_str` is only used to describe the URI argument in the usage message.

    Returns 0 on success, or 1 after printing the usage message
    if the command or the number of arguments is invalid.
    Exceptions raised by the database utility methods are not caught here.
    """
    #`argv` defaults to an empty tuple,
    #so it must not be indexed unconditionally when printing the usage:
    prog = argv[0] if argv else "dbutil"
    def usage(msg="invalid number of arguments"):
        print(msg)
        print("usage:")
        print(" %s %s create" % (prog, conn_str))
        print(" %s %s list" % (prog, conn_str))
        print(" %s %s add username password [uid, gid, displays, env_options, session_options]" % (prog, conn_str))
        print(" %s %s remove username [password]" % (prog, conn_str))
        print(" %s %s authenticate username password" % (prog, conn_str))
        return 1
    from xpra.platform import program_context
    with program_context("SQL Auth", "SQL Auth"):
        argc = len(argv)
        if argc<3:
            return usage()
        uri = argv[1]
        dbutil = DatabaseUtilClass(uri)
        cmd = argv[2]
        #the arguments which follow the command:
        cmd_args = argv[3:]
        if cmd=="create":
            if argc!=3:
                return usage()
            dbutil.create()
        elif cmd=="add":
            #username and password are required, the 5 other fields are optional:
            if argc<5 or argc>10:
                return usage()
            dbutil.add_user(*cmd_args)
        elif cmd=="remove":
            if argc not in (4, 5):
                return usage()
            dbutil.remove_user(*cmd_args)
        elif cmd=="list":
            if argc!=3:
                return usage()
            dbutil.list_users()
        elif cmd=="authenticate":
            if argc!=5:
                return usage()
            dbutil.authenticate(*cmd_args)
        else:
            return usage("invalid command '%s'" % cmd)
    return 0