objax.privacy package

PrivateGradValues(f, vc, noise_multiplier, …)

Computes differentially private gradients as required by DP-SGD.

apply_dp_sgd_analysis(q, noise_multiplier, steps)

Compute and print results of DP-SGD analysis.

compute_rdp(q, noise_multiplier, steps, orders)

Compute RDP of the Sampled Gaussian Mechanism.

get_privacy_spent(orders, rdp[, target_eps, …])

Compute delta (or eps) for given eps (or delta) from RDP values.

class objax.privacy.PrivateGradValues(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0, ), keygen=<objax.random.random.Generator object>)[source]

Computes differentially private gradients as required by DP-SGD. This module can be used in place of GradValues and automatically makes the optimizer differentially private.

__init__(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0, ), keygen=<objax.random.random.Generator object>)[source]

Constructs a PrivateGradValues instance.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • vc (objax.variable.VarCollection) – the variables for which to compute gradients.

  • noise_multiplier (float) – scale of standard deviation for added noise in DP-SGD.

  • l2_norm_clip (float) – value of clipping norm for DP-SGD.

  • microbatch (int) – the size of each microbatch.

  • batch_axis (Tuple[Optional[int], ..]) – the axis to use as batch during vectorization. Should be a tuple of 0s.

  • keygen (objax.random.random.Generator) – a Generator for random numbers. Defaults to objax.random.DEFAULT_GENERATOR.
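To make the roles of l2_norm_clip and noise_multiplier concrete, here is a minimal pure-Python sketch of the clip-and-noise step used in standard DP-SGD. It is a conceptual illustration only and not the objax implementation: the function name dp_sgd_gradient and its rng argument are hypothetical, and it works on plain lists with a microbatch size of 1.

```python
import math
import random


def dp_sgd_gradient(per_example_grads, l2_norm_clip, noise_multiplier, rng=None):
    """Hypothetical helper: clip each gradient, sum, add Gaussian noise, average."""
    rng = rng or random.Random(0)
    dim = len(per_example_grads[0])
    total = [0.0] * dim
    for g in per_example_grads:
        norm = math.sqrt(sum(v * v for v in g))
        # Scale down only gradients whose L2 norm exceeds the clipping bound.
        scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
        for i, v in enumerate(g):
            total[i] += v * scale
    # The noise standard deviation is noise_multiplier * l2_norm_clip.
    stddev = noise_multiplier * l2_norm_clip
    noisy = [t + rng.gauss(0.0, stddev) for t in total]
    return [v / len(per_example_grads) for v in noisy]


# Two per-example gradients with L2 norms 5.0 and 0.5; only the first is clipped.
grads = [[3.0, 4.0], [0.3, 0.4]]
clipped_mean = dp_sgd_gradient(grads, l2_norm_clip=1.0, noise_multiplier=0.0)
noisy_mean = dp_sgd_gradient(grads, l2_norm_clip=1.0, noise_multiplier=1.1)
```

Setting noise_multiplier to 0.0 isolates the effect of clipping, which is convenient for checking the clipping logic, but it provides no privacy guarantee.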

reshape_microbatch(x)[source]

Reshapes examples into microbatches. DP-SGD requires that per-example gradients are clipped and noised, which can be inefficient. To speed this up, it is possible to clip and noise a microbatch of examples instead, at a slight cost to privacy. If speed is not an issue, the microbatch size should be set to 1.

If x has shape [D0, D1, …, Dn], the reshaped output will have shape [number_of_microbatches, microbatch_size, D1, …, Dn], where number_of_microbatches is D0 divided by microbatch_size.
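The shape arithmetic can be sketched in a few lines of plain Python. The helper microbatch_shape is hypothetical and only computes the resulting shape; the explicit divisibility check is an assumption made for this sketch, not documented behavior of reshape_microbatch.

```python
def microbatch_shape(shape, microbatch):
    """Hypothetical helper: shape after splitting the batch axis into microbatches."""
    batch = shape[0]
    # Assumption for this sketch: the batch must split evenly into microbatches.
    if batch % microbatch != 0:
        raise ValueError("batch size must be divisible by the microbatch size")
    return (batch // microbatch, microbatch) + tuple(shape[1:])


# A batch of 64 images of shape 32x32x3, grouped into microbatches of 16.
image_batch_shape = microbatch_shape((64, 32, 32, 3), 16)
# With microbatch size 1, every example is its own microbatch.
per_example_shape = microbatch_shape((64, 32, 32, 3), 1)
```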

Parameters

x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – items to be reshaped.

Returns

The reshaped items.

Return type

Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]

__call__(*args)[source]

Returns the computed DP-SGD gradients.

Returns

A tuple (gradients, value of f).

objax.privacy.apply_dp_sgd_analysis(q, noise_multiplier, steps, orders=(1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 3.0, 3.5, 4.0, 4.5, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), delta=1e-05)[source]

Compute and print results of DP-SGD analysis.

Parameters
  • q (float) – The sampling rate.

  • noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.

  • steps (int) – The number of steps.

  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

  • delta (float) – The target delta.

Returns

eps, the privacy budget epsilon corresponding to the target delta.

Raises

ValueError – If the target delta is invalid.

Return type

float
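The q and steps arguments are usually derived from the training configuration. The sketch below shows one common convention, assuming fixed-size batches: q is the batch size divided by the dataset size, and steps is the total number of optimizer updates. The helper sampling_parameters and the example numbers are hypothetical, and this convention is an assumption rather than something this page specifies.

```python
def sampling_parameters(dataset_size, batch_size, epochs):
    """Hypothetical helper: return (q, steps) for a run with fixed-size batches."""
    # Assumed convention: sampling rate = fraction of the dataset in one batch.
    q = batch_size / dataset_size
    # One step per full batch; a trailing partial batch is ignored here.
    steps = epochs * (dataset_size // batch_size)
    return q, steps


q, steps = sampling_parameters(dataset_size=60000, batch_size=256, epochs=10)
```

The resulting q and steps could then be passed, together with the noise_multiplier used for training, to apply_dp_sgd_analysis.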

objax.privacy.compute_rdp(q, noise_multiplier, steps, orders)[source]

Compute RDP of the Sampled Gaussian Mechanism.

Parameters
  • q (float) – The sampling rate.

  • noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.

  • steps (int) – The number of steps.

  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

Returns

The RDP values at all orders; these can be np.inf.
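For intuition about the returned values, the special case q = 1 (no subsampling) has a simple closed form: the Gaussian mechanism with noise multiplier σ has RDP α / (2σ²) at order α, and RDP adds up across steps. The sketch below covers only that special case; gaussian_rdp is a hypothetical name, and the general subsampled computation performed by compute_rdp is more involved.

```python
def gaussian_rdp(noise_multiplier, steps, orders):
    """Hypothetical helper: RDP of `steps` compositions of the Gaussian mechanism.

    Covers only the unsampled case (q = 1), where the RDP at order alpha is
    alpha / (2 * sigma**2) per step and composes additively over steps.
    """
    return [steps * order / (2.0 * noise_multiplier ** 2) for order in orders]


rdp = gaussian_rdp(noise_multiplier=2.0, steps=100, orders=(2.0, 4.0, 8.0))
```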

objax.privacy.get_privacy_spent(orders, rdp, target_eps=None, target_delta=None)[source]

Compute delta (or eps) for given eps (or delta) from RDP values.

Parameters
  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

  • rdp (Tuple[float, ..]) – An array of RDP values. Must be of the same length as the orders list.

  • target_eps (float) – If not None, the epsilon for which we compute the corresponding delta.

  • target_delta (float) – If not None, the delta for which we compute the corresponding epsilon. Exactly one of target_eps and target_delta must be None.

Returns

A tuple (eps, delta, opt_order).

Raises

ValueError – If target_eps and target_delta are not specified correctly, that is, if both or neither of them are None.

Return type

Tuple[float, float, float]
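As a rough illustration of the delta-to-epsilon direction, the sketch below applies the classic RDP conversion eps = rdp + log(1 / delta) / (order - 1) at each order and keeps the smallest result. The function eps_from_rdp is hypothetical, and the library may use a different or tighter bound, so treat this as an approximation of the idea rather than a reimplementation of get_privacy_spent.

```python
import math


def eps_from_rdp(orders, rdp, target_delta):
    """Hypothetical helper: smallest epsilon over all orders for a given delta."""
    best_eps, best_order = math.inf, None
    for order, r in zip(orders, rdp):
        # Classic RDP-to-(eps, delta) conversion; requires order > 1.
        eps = r + math.log(1.0 / target_delta) / (order - 1.0)
        if eps < best_eps:
            best_eps, best_order = eps, order
    return best_eps, target_delta, best_order


# Illustrative RDP values at three orders (not produced by the library).
eps, delta, opt_order = eps_from_rdp((2.0, 4.0, 8.0), (1.0, 2.0, 4.0), 1e-5)
```

Higher orders shrink the log(1 / delta) term but usually come with larger RDP values, which is why the minimum is taken over a range of orders.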