# Primitive Roots

from hw3_t_util import mod_exp

def primitive_roots(n):
	"""Return all primitive roots modulo n, in ascending order.

	r is a primitive root modulo n when its multiplicative order equals
	phi(n), the number of units modulo n. Units are exactly the residues
	whose powers eventually return to 1; any other residue cannot be a
	primitive root.

	Works for any positive modulus, not only primes. Moduli that have no
	primitive roots (e.g. 8, 12, 15) yield an empty list, and n <= 1
	yields an empty list.

	:param n: the modulus, a positive integer.
	:returns: list of primitive roots r with 1 <= r < n.
	"""
	def multiplicative_order(r):
		"""Return the least k >= 1 with r**k % n == 1, or None if r is not a unit."""
		t = r % n
		# Invariant: at step k, t == r**k % n. A unit's order is at most
		# phi(n) <= n - 1, so n - 1 steps suffice; a non-unit never hits 1.
		for k in range(1, n):
			if t == 1:
				return k
			t = t * r % n
		return None

	orders = [(r, multiplicative_order(r)) for r in range(1, n)]
	# phi(n) is the number of units, i.e. the residues that have an order.
	phi = sum(1 for _, order in orders if order is not None)
	return [r for r, order in orders if order == phi]

def test():
	"""Smoke-test primitive_roots on small prime moduli; prints on success."""
	assert primitive_roots(3) == [2]
	assert primitive_roots(5) == [2, 3]
	# Parenthesized form prints the same text on Python 2 and is valid on Python 3.
	print("test pass")