summaryrefslogtreecommitdiff
path: root/frontend/afe/rpc_utils.py
blob: 14fa9877052a99ecbe55b3119923425f9a1dd20f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""\
Utility functions for rpc_interface.py.  We keep them in a separate file so that
only RPC interface functions go into that file.
"""

__author__ = 'showard@google.com (Steve Howard)'

import datetime, xmlrpclib, threading
from frontend.afe import models

def prepare_for_serialization(objects):
	"""
	Prepare Python objects to be returned via RPC.

	A non-empty list whose first element is a dict containing an 'id' key
	is treated as a list of object dicts and de-duplicated by ID first.
	Whatever results is then run through _prepare_data() for type
	conversion.
	"""
	if not isinstance(objects, list) or not objects:
		return _prepare_data(objects)
	first_item = objects[0]
	if isinstance(first_item, dict) and 'id' in first_item:
		objects = gather_unique_dicts(objects)
	return _prepare_data(objects)


def _prepare_data(data):
	"""
	Recursively process data structures, performing necessary type
	conversions to values in data to allow for RPC serialization:
	-convert datetimes to strings
	-convert tuples to lists
	"""
	if isinstance(data, dict):
		new_data = {}
		for key, value in data.iteritems():
			new_data[key] = _prepare_data(value)
		return new_data
	elif isinstance(data, list) or isinstance(data, tuple):
		return [_prepare_data(item) for item in data]
	elif isinstance(data, datetime.datetime):
		return str(data)
	else:
		return data


def gather_unique_dicts(dict_iterable):
	"""\
	Pick out unique objects (by ID) from an iterable of object dicts.

	The first dict seen for each 'id' value is kept; later duplicates are
	dropped.  Input order is preserved.
	"""
	seen_ids = set()
	unique_dicts = []
	for candidate in dict_iterable:
		candidate_id = candidate['id']
		if candidate_id in seen_ids:
			continue
		seen_ids.add(candidate_id)
		unique_dicts.append(candidate)
	return unique_dicts


def extra_job_filters(not_yet_run=False, running=False, finished=False):
	"""\
	Generate a SQL WHERE clause for job status filtering, and return it in
	a dict of keyword args to pass to query.extra().  No more than one of
	the parameters should be passed as True.

	not_yet_run: match jobs having no host_queue_entries row that is
	    active or complete.
	running: match jobs having at least one entry that is active or
	    complete AND at least one entry that is not complete or is active.
	finished: match jobs having no entry that is not complete or active.

	Returns {'where': [clause]} for the selected filter, or None when no
	filter was requested.  Raises ValueError if more than one filter is
	requested.
	"""
	# Explicit check instead of assert: asserts are stripped under -O, which
	# would silently let the first matching flag win.
	requested = [flag for flag in (not_yet_run, running, finished) if flag]
	if len(requested) > 1:
		raise ValueError('Cannot specify more than one '
				 'filter to this function')

	if not_yet_run:
		where = ['id NOT IN (SELECT job_id FROM host_queue_entries '
			 'WHERE active OR complete)']
	elif running:
		where = ['(id IN (SELECT job_id FROM host_queue_entries '
			  'WHERE active OR complete)) AND '
			 '(id IN (SELECT job_id FROM host_queue_entries '
			  'WHERE not complete OR active))']
	elif finished:
		where = ['id NOT IN (SELECT job_id FROM host_queue_entries '
			 'WHERE not complete OR active)']
	else:
		return None
	return {'where': where}


def extra_host_filters(multiple_labels=None):
	"""\
	Generate SQL WHERE clauses for matching hosts in an intersection of
	labels.

	multiple_labels: iterable of label identifiers; each one is resolved
	    with models.Label.smart_get().  None (the default) or an empty
	    iterable yields no clauses.

	Returns a dict of keyword args for query.extra(): 'where' holds one
	clause per label and 'params' holds the matching label IDs, one per
	%s placeholder, in the same order.
	"""
	# None instead of a shared mutable [] default; list() also lets callers
	# pass any iterable (e.g. a generator), not just a sequence.
	if multiple_labels is None:
		labels = []
	else:
		labels = list(multiple_labels)
	where_str = ('hosts.id in (select host_id from hosts_labels '
		     'where label_id=%s)')
	extra_args = {}
	extra_args['where'] = [where_str] * len(labels)
	extra_args['params'] = [models.Label.smart_get(label).id
				for label in labels]
	return extra_args


# Per-thread storage; each thread sees only the user it set itself.
local_vars = threading.local()

def set_user(user):
	"""\
	Sets the current request's logged-in user.  user should be a
	afe.models.User object.
	"""
	local_vars.user = user


def get_user():
	"""\
	Get the currently logged-in user as a afe.models.User object.

	Raises AttributeError if set_user() has not been called in the
	current thread.
	"""
	return local_vars.user


class InconsistencyException(Exception):
	"""\
	Raised when a list of objects does not have a consistent value.

	get_consistent_value() raises it with args (first_object,
	mismatching_object).
	"""
	pass


def get_consistent_value(objects, field):
	"""\
	Return the value of attribute `field`, which must be shared by every
	object in `objects`.

	Raises InconsistencyException(first_object, mismatching_object) on the
	first object whose value differs from that of objects[0].  An empty
	`objects` raises IndexError.
	"""
	reference = objects[0]
	expected = getattr(reference, field)
	for candidate in objects:
		if getattr(candidate, field) != expected:
			raise InconsistencyException(reference, candidate)
	return expected


def prepare_generate_control_file(tests, kernel, label):
	test_objects = [models.Test.smart_get(test) for test in tests]
	# ensure tests are all the same type
	try:
		test_type = get_consistent_value(test_objects, 'test_type')
	except InconsistencyException, exc:
		test1, test2 = exc.args
		raise models.ValidationError(
		    {'tests' : 'You cannot run both server- and client-side '
		     'tests together (tests %s and %s differ' % (
		    test1.name, test2.name)})

	try:
		synch_type = get_consistent_value(test_objects, 'synch_type')
	except InconsistencyException, exc:
		test1, test2 = exc.args
		raise models.ValidationError(
		    {'tests' : 'You cannot run both synchronous and '
		     'asynchronous tests together (tests %s and %s differ)' % (
		    test1.name, test2.name)})

	is_server = (test_type == models.Test.Types.SERVER)
	is_synchronous = (synch_type == models.Test.SynchType.SYNCHRONOUS)
	if label:
		label = models.Label.smart_get(label)

	return is_server, is_synchronous, test_objects, label


def sorted(in_list):
	"""\
	Return a new list holding the items of in_list in ascending order.

	The argument itself is left untouched.  Note that this definition
	shadows the builtin of the same name within this module.
	"""
	ordered = [item for item in in_list]
	ordered.sort()
	return ordered