try:
    from hw3_t_util import mod_exp
except ImportError:
    # NOTE(review): fallback so this module also runs without the homework
    # helper module; assumes hw3_t_util.mod_exp(b, e, m) computes
    # b**e % m, as its use below implies -- confirm.
    def mod_exp(base, exponent, modulus):
        """Return base**exponent % modulus via the built-in three-arg pow."""
        return pow(base, exponent, modulus)


def _gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    # Hand-rolled because fractions.gcd / math.gcd differ between Python 2/3.
    while b:
        a, b = b, a % b
    return a


def _prime_factors(m):
    """Return the set of distinct prime factors of m (empty for m == 1)."""
    factors = set()
    d = 2
    while d * d <= m:
        while m % d == 0:
            factors.add(d)
            m //= d
        d += 1
    if m > 1:
        # Whatever remains after trial division up to sqrt(m) is prime.
        factors.add(m)
    return factors


def primitive_roots(n):
    """Return all primitive roots modulo n, in ascending order.

    A primitive root modulo n is a residue r in [1, n) that is coprime to n
    and whose multiplicative order equals phi(n), so its powers generate
    every residue coprime to n. Such roots exist only for n = 2, 4, p**k and
    2*p**k (p an odd prime); for any other n the result is empty.

    Args:
        n: the modulus. Values below 2 yield an empty list.

    Returns:
        A list of the primitive roots modulo n, smallest first.
    """
    if n < 2:
        return []

    # The units modulo n (residues coprime to n); there are phi(n) of them.
    units = [r for r in range(1, n) if _gcd(r, n) == 1]
    phi = len(units)

    # Requiring r**1 .. r**(n-1) to be distinct only works for prime n
    # (where phi(n) == n - 1). In general, a unit r has order phi exactly
    # when r**(phi // q) != 1 for every prime q dividing phi, because any
    # proper divisor of phi divides one of those phi // q.
    factors = _prime_factors(phi)
    return [r for r in units
            if all(mod_exp(r, phi // q, n) != 1 for q in factors)]


def test():
    assert primitive_roots(3) == [2]
    assert primitive_roots(5) == [2, 3]
    # Composite moduli that have primitive roots, and one that has none.
    assert primitive_roots(4) == [3]
    assert primitive_roots(9) == [2, 5]
    assert primitive_roots(8) == []
    print("test pass")