"""mock utilities for testing Functions --------- - mock_authenticate - mock_check_account - mock_open_session Spawners -------- - MockSpawner: based on LocalProcessSpawner - SlowSpawner: - NeverSpawner: - BadSpawner: - SlowBadSpawner - FormSpawner Other components ---------------- - MockPAMAuthenticator - MockHub - MockSingleUserServer - InstrumentedSpawner - public_host - public_url """ import asyncio import os import sys from concurrent.futures import ThreadPoolExecutor from tempfile import NamedTemporaryFile from unittest import mock from urllib.parse import urlparse from pamela import PAMError from sqlalchemy import event from tornado.httputil import url_concat from traitlets import Bool, Dict, default from .. import metrics, orm, roles from ..app import JupyterHub from ..auth import PAMAuthenticator from ..spawner import SimpleLocalProcessSpawner from ..utils import random_port, url_path_join, utcnow from .utils import AsyncSession, public_url, ssl_setup def mock_authenticate(username, password, service, encoding): # just use equality for testing if password == username: return True else: raise PAMError("Fake") def mock_check_account(username, service, encoding): if username.startswith('notallowed'): raise PAMError("Fake") else: return True def mock_open_session(username, service, encoding): pass class MockSpawner(SimpleLocalProcessSpawner): """Base mock spawner - disables user-switching that we need root permissions to do - spawns `jupyterhub.tests.mocksu` instead of a full single-user server """ def user_env(self, env): env = super().user_env(env) if self.handler: env['HANDLER_ARGS'] = self.handler.request.query return env @default('cmd') def _cmd_default(self): return [sys.executable, '-m', 'jupyterhub.tests.mocksu'] use_this_api_token = None def start(self): # preserve any JupyterHub env in mock spawner for key in os.environ: if 'JUPYTERHUB' in key and key not in self.env_keep: self.env_keep.append(key) if self.use_this_api_token: self.api_token = self.use_this_api_token elif self.will_resume: self.use_this_api_token = self.api_token return super().start() class SlowSpawner(MockSpawner): """A spawner that takes a few seconds to start""" delay = 5 _start_future = None async def start(self): (ip, port) = await super().start() if self._start_future is not None: await self._start_future else: await asyncio.sleep(self.delay) return ip, port async def stop(self): await asyncio.sleep(self.delay) await super().stop() class NeverSpawner(MockSpawner): """A spawner that will never start""" @default('start_timeout') def _start_timeout_default(self): return 1 def start(self): """Return a Future that will never finish""" return asyncio.Future() async def stop(self): pass async def poll(self): return 0 class BadSpawner(MockSpawner): """Spawner that fails immediately""" def start(self): raise RuntimeError("I don't work!") class SlowBadSpawner(MockSpawner): """Spawner that fails after a short delay""" async def start(self): await asyncio.sleep(0.5) raise RuntimeError("I don't work!") class FormSpawner(MockSpawner): """A spawner that has an options form defined""" options_form = "IMAFORM" def options_from_form(self, form_data): options = {'notspecified': 5} if 'bounds' in form_data: options['bounds'] = [int(i) for i in form_data['bounds']] if 'energy' in form_data: options['energy'] = form_data['energy'][0] if 'hello_file' in form_data: options['hello'] = form_data['hello_file'][0] if 'illegal_argument' in form_data: raise ValueError("You are not allowed to specify 'illegal_argument'") return options class FalsyCallableFormSpawner(FormSpawner): """A spawner that has a callable options form defined returning a falsy value""" options_form = lambda a, b: "" class MockStructGroup: """Mock grp.struct_group""" def __init__(self, name, members, gid=1111): self.gr_name = name self.gr_mem = members self.gr_gid = gid class MockStructPasswd: """Mock pwd.struct_passwd""" def __init__(self, name, gid=1111): self.pw_name = name self.pw_gid = gid class MockPAMAuthenticator(PAMAuthenticator): auth_state = None # If true, return admin users marked as admin. return_admin = False @default('admin_users') def _admin_users_default(self): return {'admin'} def system_user_exists(self, user): # skip the add-system-user bit return not user.name.startswith('dne') async def authenticate(self, *args, **kwargs): with mock.patch.multiple( 'pamela', authenticate=mock_authenticate, open_session=mock_open_session, close_session=mock_open_session, check_account=mock_check_account, ): username = await super().authenticate(*args, **kwargs) if username is None: return elif self.auth_state: return {'name': username, 'auth_state': self.auth_state} else: return username class MockHub(JupyterHub): """Hub with various mock bits""" # disable some inherited traits with hardcoded values db_file = None last_activity_interval = 2 log_datefmt = '%M:%S' @default('log_level') def _default_log_level(self): return 10 # MockHub additional traits external_certs = Dict() def __init__(self, *args, **kwargs): if 'internal_certs_location' in kwargs: cert_location = kwargs['internal_certs_location'] kwargs['external_certs'] = ssl_setup(cert_location, 'hub-ca') super().__init__(*args, **kwargs) if 'allow_all' not in self.config.Authenticator: self.config.Authenticator.allow_all = True @default('subdomain_host') def _subdomain_host_default(self): return os.environ.get('JUPYTERHUB_TEST_SUBDOMAIN_HOST', '') @default('bind_url') def _default_bind_url(self): if self.subdomain_host: port = urlparse(self.subdomain_host).port else: port = random_port() return 'http://127.0.0.1:%i/@/space%%20word/' % (port,) @default('ip') def _ip_default(self): return '127.0.0.1' @default('port') def _port_default(self): if self.subdomain_host: port = urlparse(self.subdomain_host).port if port: return port return random_port() @default('authenticator_class') def _authenticator_class_default(self): return MockPAMAuthenticator @default('spawner_class') def _spawner_class_default(self): return MockSpawner def init_signal(self): pass def load_config_file(self, *args, **kwargs): pass def init_tornado_application(self): """Instantiate the tornado Application object""" super().init_tornado_application() # reconnect tornado_settings so that mocks can update the real thing self.tornado_settings = self.users.settings = self.tornado_application.settings def init_services(self): # explicitly expire services before reinitializing # this only happens in tests because re-initialize # does not occur in a real instance for service in self.db.query(orm.Service): self.db.expire(service) return super().init_services() test_clean_db = Bool(True) def init_db(self): """Ensure we start with a clean user & role list""" super().init_db() if self.test_clean_db: for user in self.db.query(orm.User): self.db.delete(user) for group in self.db.query(orm.Group): self.db.delete(group) for role in self.db.query(orm.Role): self.db.delete(role) self.db.commit() # count db requests self.db_query_count = 0 engine = self.db.get_bind() @event.listens_for(engine, "before_execute") def before_execute(conn, clauseelement, multiparams, params, execution_options): self.db_query_count += 1 def init_logging(self): super().init_logging() # enable log propagation for pytest capture self.log.propagate = True # clear unneeded handlers self.log.handlers.clear() async def initialize(self, argv=None): self.pid_file = NamedTemporaryFile(delete=False).name self.db_file = NamedTemporaryFile() self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name if 'mysql' in self.db_url: self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'} await super().initialize([]) # add an initial user user = self.db.query(orm.User).filter(orm.User.name == 'user').first() if user is None: user = orm.User(name='user') # avoid initial state inconsistency by setting initial activity user.last_activity = utcnow() self.db.add(user) self.db.commit() metrics.TOTAL_USERS.inc() roles.assign_default_roles(self.db, entity=user) self.db.commit() _stop_called = False def stop(self): if self._stop_called: return self._stop_called = True # run cleanup in a background thread # to avoid multiple eventloops in the same thread errors from asyncio def cleanup(): loop = asyncio.new_event_loop() loop.run_until_complete(self.cleanup()) loop.close() with ThreadPoolExecutor(1) as pool: f = pool.submit(cleanup) # wait for cleanup to finish f.result() # prevent redundant atexit from running self._atexit_ran = True super().stop() self.db_file.close() async def login_user(self, name): """Login a user by name, returning her cookies.""" base_url = public_url(self) s = AsyncSession() if self.internal_ssl: s.verify = self.external_certs['files']['ca'] login_url = base_url + 'hub/login' r = await s.get(login_url) r.raise_for_status() xsrf = r.cookies['_xsrf'] r = await s.post( url_concat(login_url, {"_xsrf": xsrf}), data={'username': name, 'password': name}, allow_redirects=False, ) r.raise_for_status() # make second request to get updated xsrf cookie r2 = await s.get( url_path_join(base_url, "hub/home"), allow_redirects=False, ) assert r2.status_code == 200 assert sorted(s.cookies.keys()) == [ '_xsrf', 'jupyterhub-hub-login', 'jupyterhub-session-id', ] return s.cookies class InstrumentedSpawner(MockSpawner): """ Spawner that starts a full singleuser server instrumented with the JupyterHub test extension. """ @default("default_url") def _default_url(self): """Use a default_url that any jupyter server will provide Should be: - authenticated, so we are testing auth - always available (i.e. in mocked ServerApp and NotebookApp) - *not* an API handler that raises 403 instead of redirecting """ return "/tree" @default('cmd') def _cmd_default(self): return [sys.executable, '-m', 'jupyterhub.singleuser'] def start(self): self.environment["JUPYTERHUB_SINGLEUSER_TEST_EXTENSION"] = "1" return super().start()