<h1>DYffusion: A Dynamics-informed Diffusion Model for Spatiotemporal Forecasting</h1>
<p>2023-12-22</p>
<div class="l-body" align="center">
<img class="img-fluid rounded z-depth-1" src="/assets/img/2023-12-dyffusion/diagram.gif" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px;">
DYffusion forecasts a sequence of $h$ snapshots $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_h$
given the initial conditions $\mathbf{x}_0$ similarly to how standard diffusion models are used to sample from a distribution.</figcaption>
</div>
<h2 id="introduction">Introduction</h2>
<p>Obtaining <em>accurate and reliable probabilistic forecasts</em> has a wide range of applications from
climate simulations and fluid dynamics to financial markets and epidemiology.
Often, accurate <em>long-range</em> probabilistic forecasts are particularly challenging to obtain <d-cite key="300BillionServed2009, gneiting2005weather,bevacqua2023smiles"></d-cite>.
When they exist, physics-based methods typically hinge on computationally expensive
numerical simulations <d-cite key="bauer2015thequiet"></d-cite>.
In contrast, data-driven methods are much more efficient and have started to have real-world impact
in fields such as <a href="https://www.ecmwf.int/en/about/media-centre/news/2023/how-ai-models-are-transforming-weather-forecasting-showcase-data">global weather forecasting</a>.</p>
<p>Common approaches for large-scale spatiotemporal problems tend to be <em>deterministic</em> and <em>autoregressive</em>.
Thus, they are often unable to capture the inherent uncertainty in the data, produce unphysical predictions,
and are prone to error accumulation for long-range forecasts.</p>
<p>Diffusion models have shown great success for natural image and video generation.
However, diffusion models have been primarily designed for static data and are expensive to train and to sample from.
We study how we can <em>efficiently leverage them for large-scale spatiotemporal problems</em> and <em>explicitly
incorporate the temporality of the data into the diffusion model</em>.</p>
<h4 id="our-key-idea">Our Key Idea</h4>
<p>We introduce a solution for these issues by designing a temporal diffusion model, DYffusion.
Following the “generalized diffusion model” framework <d-cite key="bansal2022cold"></d-cite>, we
replace the forward and reverse processes of standard diffusion models
with dynamics-informed interpolation and forecasting, respectively.
This leads to a scalable generalized diffusion model for probabilistic forecasting that is naturally trained to forecast multiple timesteps.</p>
<h2 id="notation--background">Notation & Background</h2>
<h4 id="problem-setup">Problem setup</h4>
<p>We study the problem of probabilistic spatiotemporal forecasting using a dataset consisting of
a time series of snapshots \(\mathbf{x}_t \in \mathcal{X}\).
We focus on the task of forecasting a sequence of \(h\) snapshots from a single initial condition.
That is, we aim to train a model to learn \(P(\mathbf{x}_{t+1:t+h} \,|\, \mathbf{x}_t)\) .
Note that during evaluation, we may evaluate the model on a larger horizon \(H>h\) by running the model autoregressively.</p>
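To make the evaluation procedure concrete, here is a minimal sketch (all names hypothetical; a toy function stands in for a trained model) of how a horizon-\(h\) forecaster is applied autoregressively to cover a longer horizon \(H > h\):

```python
import numpy as np

def rollout(forecast_fn, x0, h, H):
    """Autoregressively extend a horizon-h forecaster to a longer horizon H.

    forecast_fn(x) returns h future snapshots, shape (h, *state_shape).
    """
    preds = []
    current = x0
    while len(preds) < H:
        window = forecast_fn(current)   # forecast the next h snapshots
        preds.extend(window)
        current = window[-1]            # feed the last forecast back in
    return np.stack(preds[:H])

# Toy "dynamics" standing in for a trained model: each step adds 1.
toy_model = lambda x: np.stack([x + (i + 1) for i in range(3)])
out = rollout(toy_model, np.zeros(2), h=3, H=7)   # shape (7, 2)
```

Note that any error in the last forecast of one window is inherited by the next window, which is the error-accumulation issue mentioned above.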
<h4 id="standard-diffusion-models">Standard diffusion models</h4>
<p>Diffusion models iteratively transform data between an initial distribution
and the target distribution over multiple diffusion steps <d-cite key="sohldickstein2015deepunsupervised, ho2020ddpm, karras2022edm"></d-cite>.
Here, we adapt the
<a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#forward-diffusion-process">common notation for diffusion models</a>
to use a superscript \(n\) for the diffusion states \(\mathbf{s}^{(n)}\),
to distinguish them from the timesteps of the data, \(\mathbf{x}_t\).
Given a data sample \(\mathbf{s}^{(0)}\), a standard diffusion model is defined through a <em>forward diffusion process</em>
\(q(\mathbf{s}^{(n)} \vert \mathbf{s}^{(n-1)})\)
in which small amounts of Gaussian noise are added to the sample in \(N\) steps, producing a sequence of noisy samples
\(\mathbf{s}^{(1)}, \ldots, \mathbf{s}^{(N)}\).
Adopting the notation for generalized diffusion models from <d-cite key="bansal2022cold"></d-cite>, we can also consider
a forward process operator, \(D\), that outputs the corrupted samples \(\mathbf{s}^{(n)} = D(\mathbf{s}^{(0)}, n)\).</p>
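For reference, the Gaussian forward process of a standard diffusion model can be written as such an operator \(D\) using the closed-form marginal \(\mathbf{s}^{(n)} = \sqrt{\bar\alpha_n}\,\mathbf{s}^{(0)} + \sqrt{1-\bar\alpha_n}\,\epsilon\). A minimal NumPy sketch (hypothetical helper names; DDPM-style linear \(\beta\) schedule assumed):

```python
import numpy as np

def D(s0, n, betas, rng):
    """Forward-process operator: corrupt the clean sample s0 directly to
    diffusion step n via the closed-form Gaussian marginal."""
    alpha_bar = np.cumprod(1.0 - betas)[n - 1]
    eps = rng.standard_normal(s0.shape)
    return np.sqrt(alpha_bar) * s0 + np.sqrt(1.0 - alpha_bar) * eps

betas = np.linspace(1e-4, 0.02, 1000)    # DDPM-style linear noise schedule
rng = np.random.default_rng(0)
s0 = np.ones(4)
s_mid = D(s0, 500, betas, rng)           # partially corrupted sample
s_final = D(s0, 1000, betas, rng)        # nearly pure Gaussian noise
```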
<div class="l-body">
<img class="img-fluid" src="/assets/img/2023-12-dyffusion/noise-diagram-gaussian.png" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px;"> Graphical model for a standard diffusion model.</figcaption>
</div>
<h2 id="dyffusion-dynamics-informed-diffusion-model">DYffusion: Dynamics-informed Diffusion Model</h2>
<p>The key innovation of our framework, DYffusion, is a reimagining of the diffusion processes to more naturally model
spatiotemporal sequences, \(\mathbf{x}_{t:t+h}\).
Specifically, we design the reverse (forward) process to step forward (backward) in time
so that our diffusion model emulates the temporal dynamics in
the data<d-footnote>Similarly to <d-cite key="song2021ddim, bansal2022cold"></d-cite>,
our forward and reverse processes cease to represent actual "diffusion" processes.
Unlike all prior work, however, our processes are <em>not</em> based on data corruption or restoration.</d-footnote>.</p>
<div class="l-body">
<img class="img-fluid" src="/assets/img/2023-12-dyffusion/noise-diagram-dyffusion.png" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">Graphical model for DYffusion. </figcaption>
</div>
<p>Implementation-wise, we replace the standard denoising network, \(R_\theta\), with a deterministic forecaster network, \(F_\theta\).
Because we do not have a closed-form expression for the forward process, we also need to learn it from data
by replacing the standard forward process operator, \(D\), with a stochastic interpolator network \(\mathcal{I}_\phi\).
Intermediate steps in DYffusion’s reverse process can be reused as forecasts for actual timesteps.
Another benefit of our approach is that the reverse process is initialized with the initial conditions of the dynamics
and operates in observation space at all times.
In contrast, a standard diffusion model is designed for unconditional generation, and reversing from white noise requires more diffusion steps.</p>
<h3 id="training-dyffusion">Training DYffusion</h3>
<p>We propose to learn the forward and reverse process in two separate stages:</p>
<h4 id="temporal-interpolation-as-a-forward-process">Temporal interpolation as a forward process</h4>
<p>To learn our proposed temporal forward process,
we train a time-conditioned network \(\mathcal{I}_\phi\) to interpolate between snapshots of data.
Given a horizon \(h\), we train the interpolator net so that
\(\mathcal{I}_\phi(\mathbf{x}_t, \mathbf{x}_{t+h}, i) \approx \mathbf{x}_{t+i}\) for \(i \in \{1, \ldots, h-1\}\) using the objective:</p>
\[\begin{equation}
\min_\phi
\mathbb{E}_{i \sim \mathcal{U}[\![1, h-1]\!], \mathbf{x}_{t, t+i, t+h} \sim \mathcal{X}}
\left[\|
\mathcal{I}_\phi(\mathbf{x}_t, \mathbf{x}_{t+h}, i) - \mathbf{x}_{t+i}
\|^2 \right].
\label{eq:interpolation}
\end{equation}\]
<p>Interpolation is an easier task than forecasting, and we can use the resulting interpolator
for temporal super-resolution during inference to interpolate beyond the temporal resolution of the data.
That is, the time input can be continuous, with \(i \in (0, h-1)\).
It is crucial for the interpolator, \(\mathcal{I}_\phi\),
to <em>produce stochastic outputs</em> within DYffusion so that its forward process is stochastic, and it can generate probabilistic forecasts at inference time.
We enable this using Monte Carlo dropout <d-cite key="gal2016dropout"></d-cite> at inference time.</p>
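The interpolation objective above can be sketched as a Monte Carlo loss estimate. The real \(\mathcal{I}_\phi\) is a time-conditioned neural network with MC dropout; here, all names are hypothetical and a toy linear interpolator stands in, for which the loss is exactly zero on linear dynamics:

```python
import numpy as np

rng = np.random.default_rng(0)

def interpolation_loss(interp, series, h, n_samples=256):
    """Monte Carlo estimate of the interpolation objective: sample t and
    i ~ U{1, ..., h-1}, then regress interp(x_t, x_{t+h}, i) onto x_{t+i}."""
    losses = []
    for _ in range(n_samples):
        t = int(rng.integers(0, len(series) - h))
        i = int(rng.integers(1, h))
        pred = interp(series[t], series[t + h], i)
        losses.append(np.mean((pred - series[t + i]) ** 2))
    return float(np.mean(losses))

# Toy data with linear dynamics; row t of `series` equals t everywhere.
series = np.arange(20.0)[:, None] * np.ones(3)
h = 4
# Linear interpolation is exact for linear dynamics, so the loss vanishes.
linear_interp = lambda xa, xb, i: xa + (i / h) * (xb - xa)
loss = interpolation_loss(linear_interp, series, h)
```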
<h4 id="forecasting-as-a-reverse-process">Forecasting as a reverse process</h4>
<p>In the second stage, we train a forecaster network \(F_\theta\) to forecast \(\mathbf{x}_{t+h}\)
such that \(F_\theta(\mathcal{I}_\phi(\mathbf{x}_{t}, \mathbf{x}_{t+h}, i \vert \xi), i)\approx \mathbf{x}_{t+h}\)
for \(i \in S =[i_n]_{n=0}^{N-1}\), where \(S\) denotes a schedule coupling the diffusion step to the interpolation timestep.
The interpolator network, \(\mathcal{I}\), is frozen with inference stochasticity enabled,
represented by the random variable \(\xi\).
In our experiments, \(\xi\) stands for the randomly dropped out weights of the neural network and is omitted henceforth for clarity.
Specifically, we seek to optimize the objective</p>
\[\begin{equation}
\min_\theta
\mathbb{E}_{n \sim \mathcal{U}[\![0, N-1]\!], \mathbf{x}_{t, t+h}\sim \mathcal{X}}
\left[\|
F_\theta(\mathcal{I}_\phi(\mathbf{x}_{t}, \mathbf{x}_{t+h}, i_n \vert \xi), i_n) - \mathbf{x}_{t+h}
\|^2 \right].
\label{eq:forecaster}
\end{equation}\]
<p>To include the setting where \(F_\theta\) learns to forecast the initial conditions,
we define \(i_0 := 0\) and \(\mathcal{I}_\phi(\mathbf{x}_{t}, \cdot, i_0) := \mathbf{x}_t\).
In the simplest case, the forecaster net is supervised by all timesteps given
by the temporal resolution of the training data. That is, \(N=h\) and \(S = [j]_{j=0}^{h-1}\).
Generally, the schedule should satisfy \(0 = i_0 < i_n < i_m < h\) for \(0 < n < m \leq N-1\).</p>
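The second-stage objective can be sketched analogously (hypothetical names; toy oracle functions stand in for the trained \(\mathcal{I}_\phi\) and \(F_\theta\), and the stochasticity \(\xi\) is omitted):

```python
import numpy as np

rng = np.random.default_rng(1)

def forecaster_loss(forecaster, interp, series, h, schedule, n_samples=256):
    """Monte Carlo estimate of the forecasting objective: the frozen
    interpolator produces an estimate of x_{t+i_n}, and the forecaster
    must map it to x_{t+h}. By convention, i_0 = 0 yields x_t itself."""
    losses = []
    for _ in range(n_samples):
        t = int(rng.integers(0, len(series) - h))
        i_n = schedule[int(rng.integers(0, len(schedule)))]
        s_n = series[t] if i_n == 0 else interp(series[t], series[t + h], i_n)
        pred = forecaster(s_n, i_n)
        losses.append(np.mean((pred - series[t + h]) ** 2))
    return float(np.mean(losses))

series = np.arange(20.0)[:, None] * np.ones(3)   # linear toy dynamics
h = 4
schedule = [0, 1, 2, 3]                          # simplest case: N = h, i_n = n
interp = lambda xa, xb, i: xa + (i / h) * (xb - xa)
forecaster = lambda s, i: s + (h - i)            # exact for the toy dynamics
loss = forecaster_loss(forecaster, interp, series, h, schedule)
```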
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/algo-training.png" width="75%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">DYffusion's two-stage training procedure is summarized in the algorithm above. </figcaption>
</div>
<h3 id="sampling-from-dyffusion">Sampling from DYffusion</h3>
<p>Our above design for the forward and reverse processes of DYffusion implies the following generative process:</p>
\[\begin{equation}
p_\theta(\mathbf{s}^{(n+1)} | \mathbf{s}^{(n)}, \mathbf{x}_t) =
\begin{cases}
F_\theta(\mathbf{s}^{(n)}, i_{n}) & \text{if} \ n = N-1 \\
\mathcal{I}_\phi(\mathbf{x}_t, F_\theta(\mathbf{s}^{(n)}, i_n), i_{n+1}) & \text{otherwise,}
\end{cases}
\label{eq:new-reverse}
\end{equation}\]
<p>where \(\mathbf{s}^{(0)}=\mathbf{x}_t\) and \(\mathbf{s}^{(n)}\approx\mathbf{x}_{t+i_n}\)
correspond to the initial conditions and predictions of intermediate steps, respectively.
In our formulations, we reverse the diffusion step indexing to align with the temporal indexing of the data.
That is, \(n=0\) refers to the start of the reverse process,
while \(n=N\) refers to the final output of the reverse process with \(\mathbf{s}^{(N)}\approx\mathbf{x}_{t+h}\).
Our reverse process steps forward in time, in contrast to the mapping from noise to data in standard diffusion models.
As a result, DYffusion should require fewer diffusion steps and less training data.</p>
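The generative process above can be sketched as a short sampling loop (hypothetical names; toy stand-ins for the trained networks, and no final refinement of intermediate predictions):

```python
import numpy as np

def dyffusion_sample(forecaster, interp, x_t, schedule, h):
    """Reverse process: repeatedly forecast x_{t+h}, then interpolate back to
    the timestep i_{n+1} of the next diffusion step; the last step outputs
    the forecast itself."""
    s = x_t                                  # s^(0) = x_t, with i_0 = 0
    states = [s]
    for n, i_n in enumerate(schedule):
        x_hat_h = forecaster(s, i_n)         # current estimate of x_{t+h}
        if n == len(schedule) - 1:
            s = x_hat_h
        else:
            s = interp(x_t, x_hat_h, schedule[n + 1])
        states.append(s)
    return s, states

h = 4
schedule = [0, 1, 2, 3]
interp = lambda xa, xb, i: xa + (i / h) * (xb - xa)
forecaster = lambda s, i: s + (h - i)        # exact for linear toy dynamics
x_hat, states = dyffusion_sample(forecaster, interp, np.zeros(2), schedule, h)
# `states` passes through estimates of x_1, x_2, x_3 before returning x_4.
```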
<p>DYffusion follows the generalized diffusion model framework.
Thus, we can use existing diffusion model sampling methods for inference.
In our experiments, we adapt the sampling algorithm from <d-cite key="bansal2022cold"></d-cite> to our setting as shown below.</p>
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/algo-sampling-cold.png" width="75%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">Sampling algorithm for DYffusion. </figcaption>
</div>
<p>During the sampling process, our method essentially alternates between forecasting and interpolation,
as illustrated in the figure below.
\(F_\theta\) always predicts the last timestep, \(\mathbf{x}_{t+h}\),
but iteratively improves those forecasts as the reverse process moves closer in time to \(t+h\).
This is analogous to the iterative denoising of the “clean” data in standard diffusion models.
This motivates line 6 of Alg. 2, where the final forecast of \(\mathbf{x}_{t+h}\) can be used to
fine-tune intermediate predictions or to increase the temporal resolution of the forecast.</p>
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/sampling-unrolled.png" width="75%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
During sampling, DYffusion essentially alternates between forecasting and interpolation, following Alg. 2.
In this example, the sampling trajectory follows a simple schedule of going through all integer timesteps that precede the horizon of $h=4$,
with the number of diffusion steps $N=h$.
The output of the last diffusion step is used as the final forecast for $\hat{\mathbf{x}}_4$.
The <span style="color:black;font-weight:bold">black</span> lines represent forecasts by the forecaster network, $F_\theta$.
The first forecast is based on the initial conditions, $\mathbf{x}_0$.
The <span style="color:blue;font-weight:bold">blue</span> lines represent the subsequent temporal interpolations performed by the interpolator network, $\mathcal{I}_\phi$.
</figcaption>
</div>
<h3 id="memory-footprint">Memory footprint</h3>
<p>During training, DYffusion only requires \(\mathbf{x}_t\) and \(\mathbf{x}_{t+h}\) (plus \(\mathbf{x}_{t+i}\) during the first interpolation stage),
resulting in a <em>constant memory footprint as a function of</em> \(h\).
In contrast, direct multi-step prediction models including video diffusion models or (autoregressive) multi-step loss approaches require
\(\mathbf{x}_{t:t+h}\) to compute the loss.
This means that these models must fit \(h+1\) timesteps of data into memory (and may need to compute gradients recursively through them),
which scales poorly with the training horizon \(h\).
Therefore, many are limited to predicting a small number of frames or snapshots.
For example, our main video diffusion model baseline, MCVD, trains on a maximum of 5 video frames due to GPU memory constraints <d-cite key="voleti2022mcvd"></d-cite>.</p>
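The difference can be sketched with two hypothetical batch samplers: DYffusion's training stages materialize at most three snapshots per example regardless of \(h\), whereas a direct multi-step model loads the full window:

```python
import numpy as np

def dyffusion_batch(series, h, batch_size, rng):
    """One interpolation-stage training batch: only (x_t, x_{t+i}, x_{t+h})
    are materialized, so memory is constant in h."""
    ts = rng.integers(0, len(series) - h, size=batch_size)
    i = rng.integers(1, h, size=batch_size)
    return series[ts], series[ts + i], series[ts + h]

def multistep_batch(series, h, batch_size, rng):
    """A direct multi-step model instead needs all h + 1 snapshots."""
    ts = rng.integers(0, len(series) - h, size=batch_size)
    return np.stack([series[t:t + h + 1] for t in ts])

rng = np.random.default_rng(0)
series = np.zeros((1000, 64))            # 1000 snapshots of a 64-dim state
x_t, x_ti, x_th = dyffusion_batch(series, h=100, batch_size=8, rng=rng)
full = multistep_batch(series, h=100, batch_size=8, rng=rng)
# 3 snapshots per example vs. 101 for the direct multi-step batch.
```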
<div class="l-body" align="center">
<img class="img-fluid" src="/assets/img/2023-12-dyffusion/dyffusion-vs-video-diffusion-diagram.png" width="85%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">On the top row, we illustrate the direct application of a video diffusion model to dynamics forecasting for a horizon of $h=3$.
On the bottom row, DYffusion generates continuous-time probabilistic forecasts for $\mathbf{x}_{t+1:t+h}$, given the initial conditions, $\mathbf{x}_t$.
Our approach operates in the observation space at all times and does not need to model high-dimensional videos at each diffusion state.</figcaption>
</div>
<h2 id="experimental-setup">Experimental Setup</h2>
<h4 id="datasets">Datasets</h4>
<p>We evaluate our method and baselines on three different datasets:</p>
<ol>
<li><strong>Sea Surface Temperatures (SST):</strong> a new dataset based on NOAA OISSTv2<d-cite key="huang2021oisstv2"></d-cite>, which
comes at a daily time-scale. Similarly to <d-cite key="de2018physicalsstbaseline, wang2022metalearning"></d-cite>,
we train our models on regional patches which increases the available
data<d-footnote>Here, we choose 11 boxes of $60 \times 60$ (latitude $\times$ longitude) resolution in the eastern tropical Pacific Ocean.
Unlike the data based on the NEMO dataset in <d-cite key="de2018physicalsstbaseline, wang2022metalearning"></d-cite>,
we choose OISSTv2 as our SST dataset because it contains more data (although it has a lower spatial resolution of $1/4^\circ$ compared to $1/12^\circ$ of NEMO).</d-footnote>.
We train, validate, and test all models for the years 1982-2019, 2020, and 2021, respectively.</li>
<li><strong>Navier-Stokes</strong> flow benchmark dataset from <d-cite key="otness21nnbenchmark"></d-cite>, which consists of a
\(221\times42\) grid. Each trajectory contains four randomly generated circular obstacles that block the flow.
The channels consist of the \(x\) and \(y\) velocities as well as a pressure field, and the viscosity is \(10^{-3}\).
Boundary conditions and obstacle masks are given as additional inputs to all models.</li>
<li><strong>Spring Mesh</strong> benchmark dataset from <d-cite key="otness21nnbenchmark"></d-cite>. It represents a \(10\times10\) grid of
particles connected by springs, each with mass 1. The channels consist of two position and momentum fields each.</li>
</ol>
<p>We follow the official train, validation, and test splits from <d-cite key="otness21nnbenchmark"></d-cite> for the Navier-Stokes and spring mesh datasets,
always using the full training set for training.</p>
<h4 id="baselines">Baselines</h4>
<p>We compare our method against both direct applications of standard diffusion models to dynamics forecasting and
methods that ensemble the “barebone” backbone network of each dataset, i.e., the backbone used directly, without any diffusion.
We use the following baselines:</p>
<ul>
<li><strong>DDPM</strong><d-cite key="ho2020ddpm"></d-cite>: We train it as a multi-step (video-like problem) conditional diffusion model.</li>
<li><strong>MCVD</strong><d-cite key="voleti2022mcvd"></d-cite>: A state-of-the-art conditional video diffusion model<d-footnote>We train MCVD in "concat" mode, which in their experiments performed best.</d-footnote>.</li>
<li><strong>Dropout</strong><d-cite key="gal2016dropout"></d-cite>: Ensemble multi-step forecasting of the barebone backbone network based on enabling dropout at inference time.</li>
<li><strong>Perturbation</strong><d-cite key="pathak2022fourcastnet"></d-cite>: Ensemble multi-step forecasting with the barebone backbone network based on random perturbations of the initial conditions with a fixed variance.</li>
<li>Official <strong>deterministic</strong> baselines from <d-cite key="otness21nnbenchmark"></d-cite> for
the Navier-Stokes and spring mesh datasets <d-footnote>Due to their deterministic nature, we exclude these baselines from our main probabilistic benchmarks.</d-footnote>.</li>
</ul>
<p>MCVD and the multi-step DDPM predict the timesteps \(\mathbf{x}_{t+1:t+h}\) based on \(\mathbf{x}_{t}\).
The barebone backbone network baselines are time-conditioned forecasters trained on the multi-step objective
\(\mathbb{E}_{i \sim \mathcal{U}[\![1, h]\!], \mathbf{x}_{t, t+i}\sim \mathcal{X}}
\| F_\theta(\mathbf{x}_{t}, i) - \mathbf{x}_{t+i}\|^2\)
from scratch<d-footnote>We found it to perform very similarly to predicting all $h$
horizon timesteps at once in a single forward pass, i.e. on the
objective $\mathbb{E}_{\mathbf{x}_{t:t+h}\sim \mathcal{X}} \| F_\theta(\mathbf{x}_{t}) - \mathbf{x}_{t+1:t+h}\|^2$</d-footnote>.</p>
<h4 id="neural-network-architectures">Neural network architectures</h4>
<p>For a given dataset, we use the <em>same backbone architecture</em> for all baselines as well as for both the interpolation and forecaster networks in DYffusion.
For the SST dataset, we use a <a href="https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py">popular UNet architecture</a> designed for diffusion models.
For the Navier-Stokes and spring mesh datasets, we use the UNet and CNN from the original benchmark paper <d-cite key="otness21nnbenchmark"></d-cite>, respectively.
The UNet and CNN models from <d-cite key="otness21nnbenchmark"></d-cite> are extended by the sine/cosine-based featurization module of the SST UNet to embed the diffusion step or dynamical timestep.</p>
<h4 id="evaluation-metrics">Evaluation metrics</h4>
<p>We evaluate the models by generating an M-member ensemble (i.e. M samples are drawn per batch element), where
we use M=20 for validation and M=50 for testing.
As metrics, we use the Continuous Ranked Probability Score (CRPS) <d-cite key="matheson1976crps"></d-cite>,
the mean squared error (MSE), and the spread-skill ratio (SSR).
The CRPS is a proper scoring rule and a popular metric in the probabilistic forecasting
literature<d-cite key="gneiting2014Probabilistic, bezenac2020normalizing, Rasul2021AutoregressiveDD, rasp2018postprocessing, scher2021ensemble"></d-cite>.
The MSE is computed on the ensemble mean prediction.
The SSR is defined as the ratio of the square root of the ensemble variance to the corresponding ensemble mean RMSE.
It serves as a measure of the reliability of the ensemble, where values smaller than 1 indicate
underdispersion<d-footnote>That is, the probabilistic forecast is overconfident and fails to model the full uncertainty of the forecast</d-footnote>
and larger values overdispersion<d-cite key="fortin2014ssr, garg2022weatherbenchprob"></d-cite>.
On the Navier-Stokes and spring mesh datasets, models are evaluated by autoregressively forecasting the full test trajectories of length 64 and 804, respectively.
For the SST dataset, all models are evaluated on forecasts of up to 7 days<d-footnote>We do not explore more long-term SST forecasts because the chaotic nature of the system, and the fact that we only use regional patches, inherently limits predictability.</d-footnote>.</p>
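For reference, the ensemble CRPS estimator and the SSR can be sketched as follows (a common empirical CRPS form, \(\mathbb{E}|X - y| - \tfrac{1}{2}\mathbb{E}|X - X'|\); all names hypothetical):

```python
import numpy as np

def crps_ensemble(ens, y):
    """Empirical CRPS: E|X - y| - 0.5 * E|X - X'| over the M members.
    ens has shape (M, ...); y matches ens[0]."""
    term1 = np.abs(ens - y).mean()
    term2 = np.abs(ens[:, None] - ens[None, :]).mean()
    return term1 - 0.5 * term2

def spread_skill_ratio(ens, y):
    """SSR = sqrt(mean ensemble variance) / RMSE of the ensemble mean.
    Values below 1 indicate an underdispersive (overconfident) ensemble."""
    spread = np.sqrt(ens.var(axis=0, ddof=1).mean())
    skill = np.sqrt(((ens.mean(axis=0) - y) ** 2).mean())
    return spread / skill

rng = np.random.default_rng(0)
signal = rng.standard_normal(1000)
y = signal + rng.standard_normal(1000)          # truth = signal + noise
ens = signal + rng.standard_normal((50, 1000))  # calibrated 50-member ensemble
ssr = spread_skill_ratio(ens, y)                # close to 1 when calibrated
crps = crps_ensemble(ens, y)
```

A calibrated ensemble also attains a lower CRPS here than collapsing all members onto the ensemble mean, which is why CRPS rewards both accuracy and spread.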
<h2 id="results">Results</h2>
<h3 id="quantitative-results">Quantitative results</h3>
<p>We present the time-averaged metrics for the SST and Navier-Stokes dataset in the table below.
DYffusion performs best on the Navier-Stokes dataset, while coming in a close second on the SST dataset after MCVD, in terms of CRPS.
Since MCVD uses 1000 diffusion steps,
it is slower to sample from at inference time than DYffusion, which is trained with at most 35 diffusion steps.
The DDPM model for the SST dataset is fairly efficient because it only uses 5 diffusion steps but lags in terms of performance.</p>
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/results-table-main.png" width="95%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
Results for sea surface temperature forecasting of 1 to 7 days ahead, and Navier-Stokes
flow full trajectory forecasting of 64 timesteps.
For SST, all models are trained on forecasting $h=7$ timesteps. The time column represents the time needed to forecast all 7 timesteps for a single batch.
For Navier-Stokes, Perturbation, Dropout, and DYffusion are trained on a horizon of $h=16$.
MCVD and DDPM are trained on $h=4$ and $h=1$, respectively, as we could not successfully train them using larger horizons.
<span style="font-weight:bold">Bold</span> indicates best, <span style="color:blue">blue</span> second best.
For CRPS and MSE, lower is better. For SSR, closer to 1 is better. Numbers are averaged over the evaluation horizon.
</figcaption>
</div>
<p>Thanks to the dynamics-informed and memory-efficient nature of DYffusion, we can scale our framework to long horizons.
On the spring mesh dataset, we train with a horizon of 134 and evaluate the models on trajectories of 804 time steps.
Our method beats the Dropout baseline, with a larger margin on the out-of-distribution test dataset.
Despite several attempts with varying hyperparameter configurations, neither the DDPM nor the MCVD diffusion model converged on this dataset.</p>
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/results-table-spring-mesh.png" width="95%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
Spring Mesh results. Both methods are trained on a horizon of $h = 134$ timesteps and
evaluated on how well they forecast the full test trajectories of 804 steps.
For CRPS and MSE, lower is better. For SSR, closer to 1 is better. Numbers are averaged over the evaluation horizon.
</figcaption>
</div>
<p>The reported MSE scores above, using the same CNN architecture,
are significantly better than the ones reported for the official CNN baselines in Fig. 8 of <d-cite key="otness21nnbenchmark"></d-cite>,
where the deterministic CNN diverged or attained a very poor MSE.
This is likely because our models are trained to forecast multiple timesteps,
while the models from <d-cite key="otness21nnbenchmark"></d-cite> are trained to forecast the next timestep only.
As a result, the training objective significantly deviates from the evaluation procedure,
which was already noted as a limitation of the benchmark baselines in <d-cite key="otness21nnbenchmark"></d-cite>.
This effect is also present for the Navier-Stokes dataset, albeit to a lesser extent, as demonstrated in the figures below.</p>
<div class="row l-body">
<div class="col-sm">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/mse-vs-time-navier-stokes.png" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">Navier-Stokes</figcaption>
</div>
<div class="col-sm">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/mse-vs-time-spring-mesh.png" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">Spring Mesh</figcaption>
</div>
<!--- add joint caption here --->
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
Comparison against single-step deterministic baselines from <d-cite key="otness21nnbenchmark"></d-cite>.
We plot the MSE as a function of the rollout time step.
For spring mesh, we plot each of the three models trained with a different random seed separately due to the high variance.
</figcaption>
</div>
<h3 id="qualitative-results">Qualitative results</h3>
<p>Long-range forecasts of ML models often suffer from blurriness or might even diverge when using autoregressive models.
In the video below, we show a complete Navier-Stokes test trajectory forecasted by DYffusion and the best baseline, Dropout, as well as the corresponding ground truth.
Our method can reproduce the true dynamics over the full trajectory and does so better than the baseline,
especially for fine-scale patterns such as the tails of the flow after the right-most obstacle.</p>
<div class="l-body" align="center">
<video width="100%" controls="">
<source src="/assets/img/2023-12-dyffusion/ls9vw31m-kwy9mak6-5fps.mp4" type="video/mp4" />
</video>
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
Exemplary samples from DYffusion and the best baseline, Dropout, as well as the corresponding ground truth from a complete Navier-Stokes trajectory forecast.
</figcaption>
</div>
<h3 id="temporal-super-resolution-and-sample-variability">Temporal super-resolution and sample variability</h3>
<p>Motivated by the continuous-time nature of DYffusion, in this experiment we study whether it is possible to forecast
skillfully beyond the temporal resolution given by the data.
Here, we forecast the same Navier-Stokes trajectory shown in the video above but at \(8\times\) resolution.
That is, DYffusion forecasts 512 timesteps instead of 64 in total.
This behavior can be achieved by either changing the sampling trajectory \([i_n]_{n=0}^{N-1}\) or
by including additional output timesteps, \(J\), for the refinement step of line 6 in Alg. 2.
In the video below, we choose the latter and find the 5 sampled forecasts to be visually pleasing and temporally consistent with the ground truth.</p>
<div class="l-body" align="center">
<video width="100%" controls="">
<source src="/assets/img/2023-12-dyffusion/yivdhhzu-trajectory4-0.125res-timeDependentTruthBoundary-5samples-5fps.mp4" type="video/mp4" />
</video>
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
$8\times$ temporal super-resolution of a Navier-Stokes trajectory with DYffusion.
The ground truth is frozen in-between the original timesteps. Five distinct samples are shown.
</figcaption>
</div>
<p>Note that we want our probabilistic forecasting model to capture any of the possible,
uncertain futures rather than forecast their mean, as a deterministic model would.
As a result, some long-term rollout samples are expected to deviate from the ground truth.
For example, see the velocity at <em>t</em>=3.70 in the video above.
It is reassuring that DYffusion’s samples show sufficient variation, but also cover the ground truth quite well (sample 1).
This advantage is also reflected quantitatively in the spread-skill ratio (SSR) metric, where DYffusion
consistently reaches values close to 1.</p>
<h3 id="iterative-refinement-of-forecasts">Iterative refinement of forecasts</h3>
<p>DYffusion’s forecaster network repeatedly predicts the same timestep, \(t+h\), during sampling.
Thus, we need to verify that these forecasts,
\(\hat{\mathbf{x}}_{t+h} = F_\theta(\mathbf{x}_{t+i_n}, i_n)\), tend to improve throughout the course of the reverse process,
i.e. as \(n\rightarrow N\) and \(\mathbf{x}_{t+i_n}\rightarrow\mathbf{x}_{t+h}\).
Below we show that this is indeed the case for the Navier-Stokes dataset.
Generally, we find that this observation tends to hold especially for the probabilistic metrics, CRPS and SSR,
while the trend is less clear for the MSE across all datasets (see Fig. 7 of <a href="https://arxiv.org/abs/2306.01984">our paper</a>).</p>
<div class="l-body" align="center">
<img class="img-fluid rounded" src="/assets/img/2023-12-dyffusion/diffusion-step-vs-metric-navier-stokes.png" width="100%" />
<figcaption style="text-align: center; margin-top: 10px; margin-bottom: 10px">
DYffusion's forecaster network iteratively improves its forecasts during sampling.
</figcaption>
</div>
<h2 id="conclusion">Conclusion</h2>
<p>DYffusion is the first diffusion model that relies on task-informed forward and reverse processes.
Other existing diffusion models, albeit more general, use data corruption-based processes.
Thus, our work provides a new perspective on designing a capable diffusion model,
and we hope that it will lead to a whole family of task-informed diffusion models.</p>
<p>If you have any application that you think could benefit from DYffusion, or build on top of it, we would love to hear from you!</p>
<p>For more details, please <strong><em>check out our <a href="https://arxiv.org/abs/2306.01984">NeurIPS 2023 paper</a>,
and our <a href="https://github.com/Rose-STL-Lab/dyffusion">code on GitHub</a></em></strong>.</p>
<p>Salva Rühling Cachay</p>
<h1>How to Actively Learn in Bounded Memory</h1>
<p>2021-11-21</p>
<h3 id="a-brief-introduction-enriched-queries-and-memory-constraints">A Brief Introduction: Enriched Queries and Memory Constraints</h3>
<p>In the world of big data, machine learning practice is dominated by massive supervised algorithms, techniques that require huge troves of labeled data to reach state of the art accuracy. While certainly successful in their own right, these methods <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7104701/">break down in important scenarios like disease classification</a> where labeling is expensive, and accuracy can be the difference between life and death. <a href="https://ucsdml.github.io/jekyll/update/2020/07/27/rel-comp.html">In a previous post</a>, we discussed a new technique for tackling these high risk scenarios using <em>enriched queries</em>: informative questions beyond labels (e.g., <em>comparing</em> data points). While the resulting algorithms use very few labeled data points and never make errors, their efficiency comes at a cost: <strong>memory usage</strong>.</p>
<p>For simplicity, in this post we’ll consider the following basic setup. Let $X$ be a set of $n$ labeled points, where the labeling is chosen from some underlying family of classifiers (e.g., linear classifiers). As the learner, we are given access to the (unlabeled) points in $X$, a <em>labeling oracle</em> we can call to learn the label of any particular $x \in X$, and a set of special <em>enriched oracles</em> that give further information about the underlying classifier (e.g., a <em>comparison oracle</em> which can compare any two points $x,x’ \in X$). Our goal is to learn the label of every point in $X$ in as few queries (calls to the oracle) as possible.</p>
<p><a href="https://arxiv.org/abs/1704.03564">Traditional techniques</a> for solving this problem aim to use only $\log(n)$ adaptive queries. For instance, if $X$ is a set of points on the real line and the labeling is promised to come from some threshold, we can achieve this using just a labeling oracle and binary search. This gives an exponential improvement over the naive algorithm of requesting the label of every point! However, these strategies generally have a problem: in order to choose the most informative queries, they allow the algorithm access to all of $X$, implicitly assuming the entire dataset is stored in memory. Since we frequently deal with massive datasets in practice, this strategy quickly becomes intractable. In this post, we’ll discuss a new compression-based characterization of when it’s possible to learn in $\log(n)$ queries, but store only a <strong>constant</strong> number of points in the process.</p>
<h3 id="a-basic-example-learning-thresholds-via-compression">A Basic Example: Learning Thresholds via Compression</h3>
<p>Learning in constant memory may seem a tall order when the algorithm is already required to correctly recover every label in a size $n$ set $X$ in only $\log(n)$ queries. To convince the reader such a feat is even possible, let’s start with a fundamental example using only label queries: thresholds in 1D. Let $X$ be any set of $n$ points on $\mathbb{R}$ with (hidden) labels given by some threshold. We’d like to learn the label of every point in $X$ in around $\log(n)$ adaptive queries of the form “what is the label of $x \in X$?” Notice that to do this, it is enough to find the points directly to the right and left of the threshold—the only issue is we don’t know where they are! Classically, we’d try to find these points using binary search. This would achieve the $\log(n)$ bound on queries, but determining which point to query in each step requires too much memory.</p>
<p>A better strategy for this problem was proposed by <a href="https://arxiv.org/abs/1704.03564">Kane, Lovett, Moran, and Zhang</a> (KLMZ). They follow a simple four step process:</p>
<ol>
<li>Randomly sample $O(1)$ points from remaining set (initially $X$ itself).</li>
<li>Query the labels of these points, and store them in memory.</li>
<li>Restrict to the set of points whose labels remain unknown.</li>
<li>Repeat $O(\log(n))$ times.</li>
</ol>
<p>Note that it is possible to remove points we have not queried in Step 3 (we call such points “inferred,” see Figure 1(c)). Indeed, KLMZ prove that despite only making $O(1)$ queries, each round should remove about half of the remaining points. As a result, after about $\log(n)$ rounds, we must have found the two points on either side of the threshold, and can therefore label all of $X$ as desired (see <a href="https://ucsdml.github.io/jekyll/update/2020/07/27/rel-comp.html">our previous post</a> for more details on this algorithm). This algorithm is much better than binary search, but it still stores $O(\log(n))$ points overall—we’d like an algorithm whose memory doesn’t scale with $n$ at all!</p>
<p>It turns out that for the class of thresholds, this can be achieved by a very simple tactic: in each round, only store the two points closest to each side of the threshold. This “compressed” version of the sample actually retains all relevant information, so the algorithm’s learning guarantees are completely unaffected. Let’s take a look pictorially.</p>
<p style="text-align: center;"><img src="/assets/2021-11-21-al-memory/threshold.png" width="90%" /></p>
<p>Since we can compress our storage down to a constant size in every round and never draw more than $O(1)$ points, this strategy results in a learner whose memory has no dependence on $X$ at all: a zero-error, query efficient, bounded memory learner.</p>
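<p>As a concrete sketch (our own illustrative code, not from KLMZ), assuming a callable label <code>oracle</code>, the bounded-memory threshold learner looks like this:</p>

```python
import random

def learn_threshold(points, oracle, sample_size=4, rng=random):
    """Actively learn a 1D threshold in bounded memory: sample O(1) points per
    round, query their labels, and keep only the closest known negative and
    closest known positive point -- everything outside that gap is inferred."""
    left, right = None, None          # closest labeled-0 point / labeled-1 point
    unknown = list(points)
    while unknown:
        for x in rng.sample(unknown, min(sample_size, len(unknown))):
            if oracle(x) == 0:
                left = x if left is None else max(left, x)
            else:
                right = x if right is None else min(right, x)
        # compression step: memory is just (left, right); drop inferred points
        unknown = [x for x in unknown
                   if (left is None or x > left) and (right is None or x < right)]
    # every point <= left is negative, every point >= right is positive
    return lambda x: int(right is not None and x >= right)
```

<p>Memory never exceeds the two stored boundary points plus one $O(1)$-size batch, and with high probability the loop finishes in $O(\log(n))$ rounds.</p>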
<h3 id="a-general-framework-lossless-sample-compression">A General Framework: Lossless Sample Compression</h3>
<p>Our example for thresholds in 1D suggests the following paradigm: if we can compress samples down to $O(1)$ points without harming inference, bounded memory learning is possible. This is true, but not particularly useful: most classes beyond thresholds can’t even be actively learned (e.g., <a href="https://cseweb.ucsd.edu/~dasgupta/papers/greedy.pdf">halfspaces in $2D$</a>), much less in bounded memory. To build learners for classes beyond thresholds, we’ll need to generalize our idea of compression to the <em>enriched query</em> regime. In more detail, let $X$ be a set and $H$ a family of binary labelings of $X$. We consider classes $(X,H)$ with an additional query set $Q$. Formally, $Q$ consists of a set of oracles that contain information about the set $X$ based upon the structure of the underlying hypothesis $h \in H$. Our formal definition of these oracles is fairly broad (see <a href="https://arxiv.org/abs/2102.05047">our paper</a> for exact details), but they can be thought of simply as functions dependent on the underlying hypothesis $h \in H$ that give additional structural information about tuples in $X$. One standard example is the <em>comparison oracle</em> on halfspaces. Given a particular halfspace $\langle \cdot, v \rangle$, the learner may send a pair $x,x'$ to the comparison oracle to learn which example is closer to the decision boundary, or equivalently, they receive $\text{sign}(\langle x, v \rangle - \langle x', v \rangle)$.</p>
<p>To generalize our compression-based strategy for thresholds to the enriched query setting, we also need to discuss a little bit of background on the theory of inference. Let $(X,H)$ be a hypothesis class with associated query set $Q$. Given a sample $S \subset X$ and query response $Q(S)$, denote by $H_{Q(S)}$ the set of hypotheses consistent with $Q(S)$ (also called the <em>version space</em>, this is the set of $h \in H$ such that $Q(S)$ is a valid response if $h$ is the true underlying classifier). We say that $Q(S)$ <em>infers</em> some $x \in X$ if all consistent classifiers label $x$ the same, that is if there exists $z \in$ {$0,1$} such that:
\[
\forall h \in H_{Q(S)}, h(x)=z.
\]
This allows us to label $x$ with 100% certainty, since the true underlying classifier must lie in $H_{Q(S)}$ by definition, and all such classifiers give the same label to $x$!</p>
<p>In the case of thresholds, our compression strategy relied on the fact that the two points closest to the boundary inferred the same amount of information as the original sample. We can extend this idea naturally to the enriched query regime as well.</p>
<div class="definition">
Let $X$ be a set and $H$ a family of binary classifiers on $X$. We say $(X,H)$ has a lossless compression scheme (LCS) $W$ of size $k$ with respect to a set of enriched queries $Q$ if for all subsets $S \subset X$ and all query responses $Q(S)$, there exists a subset $W = W(Q(S)) \subseteq S$ such that $|W| \leq k$, and any point in $X$ whose label is inferred by $Q(S)$ is also inferred by $Q(W)$.
</div>
<p>Recall our goal is to correctly label every point in $X$. Using lossless compression, we can now state our general algorithm for this process:</p>
<ol>
<li>Randomly sample $O(1)$ points from remaining set (initially $X$ itself).</li>
<li>Make all queries on these points, and store them in memory.</li>
<li>Compress memory via the lossless compression scheme.</li>
<li>Restrict to the set of points whose labels remain unknown.</li>
<li>Repeat $O(\log(n))$ times.</li>
</ol>
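<p>A minimal sketch of this five-step loop (our own code; <code>query</code>, <code>compress</code>, and <code>inferred</code> are hypothetical callbacks standing in for the oracles and the lossless compression scheme):</p>

```python
import random

def bounded_memory_learn(X, query, compress, inferred, sample_size=4, rng=random):
    """Sample a batch, query it, compress memory via the LCS, and drop every
    point the compressed memory already infers. Memory never exceeds the
    compression size k plus one O(1)-size batch."""
    memory = []                              # compressed (point, response) pairs
    unknown = list(X)
    while unknown:
        batch = rng.sample(unknown, min(sample_size, len(unknown)))
        memory = compress(memory + [(x, query(x)) for x in batch])
        unknown = [x for x in unknown if not inferred(memory, x)]
    return memory
```

<p>For 1D thresholds, for instance, <code>query</code> is the label oracle, <code>compress</code> keeps only the largest known negative and smallest known positive point, and <code>inferred</code> checks whether a point falls on the known side of either.</p>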
<p>In <a href="https://arxiv.org/abs/2102.05047">recent work</a> with <a href="https://cseweb.ucsd.edu/~dakane/">Daniel Kane</a>, <a href="https://cseweb.ucsd.edu/~slovett/home.html">Shachar Lovett</a>, and <a href="https://sites.google.com/view/michal-moshkovitz">Michal Moshkovitz</a>, we prove that this basic algorithm achieves zero-error, query optimal, bounded memory learning.</p>
<div class="theorem">
If $(X,H)$ has a size-$k$ LCS with respect to $Q$, then the above algorithm correctly labels all points in $X$ in
\[
O_k(\log(n)) \text{ queries}
\]
and
\[
O_k(1) \text{ memory}.
\]
</div>
<p>Before moving on to some examples, let’s take a brief moment to discuss the proof. The result essentially follows in two steps. First, we’d like to show that for any distribution over $X$, drawing $O(k)$ points is sufficient to infer $1/2$ of $X$ in expectation. This follows similarly to standard results in the literature—one can either use the classic sample compression arguments of <a href="https://link.springer.com/content/pdf/10.1023/A:1022660318680.pdf">Floyd and Warmuth</a>, or more recent symmetry arguments of KLMZ. With this in hand, it’s easy to see that after $\log(n)$ rounds (learning $1/2$ of $X$ each round), we’ll have learned all of $X$. The second step is then to observe that our compression in each step has no effect on this learning procedure. This follows without too much difficulty from the definition of lossless sample compression, which promises that the compressed sub-sample preserves all such information.</p>
<h3 id="example-axis-aligned-rectangles">Example: Axis-Aligned Rectangles</h3>
<p>While interesting in its own right, a sufficient condition like Lossless Sample Compression is most useful if it applies to natural classifiers. We’ll finish our post by discussing an application of this paradigm to labeling a dataset $X$ when the underlying classifier is given by an <em>axis-aligned rectangle</em>. Axis-aligned Rectangles are a natural generalization of intervals to higher dimensions. They are given by a <em>product</em> of $d$ intervals in $\mathbb{R}$:
\[
R = \prod\limits_{i=1}^d [a_i,b_i],
\]
such that an example $x=(x_1,\ldots,x_d) \in \mathbb{R}^d$ lies in the rectangle if every feature lies inside the specified interval, that is $x_i \in [a_i,b_i]$.</p>
<p style="text-align: center;"><img src="/assets/2021-11-21-al-memory/rectangle.png" width="40%" /></p>
<p><a href="https://cseweb.ucsd.edu/~dasgupta/papers/greedy.pdf">Standard arguments</a> show that with only labels, learning the labels of a set $X$ of size $n$ takes $\Omega(n)$ queries in the worst case when the labeling is given by some underlying rectangle. To see why, let’s consider the simple case of 1D—intervals. The key observation is that a sample of points $S_{\text{out}}$ lying outside the interval cannot infer any information beyond its own labels. This is because for any $x \in \mathbb{R} \setminus S_{\text{out}}$, there exists an interval that includes $x$ but not $S_{\text{out}}$ (say $I=[x-\varepsilon,x+\varepsilon]$ for some small enough $\varepsilon$), and an interval that excludes $x$ and $S_{\text{out}}$ (say $I=[x+\varepsilon,x+2\varepsilon]$). As a result, we cannot tell whether $x$ is included in the underlying interval. In turn, this means that if we try to compress $S_{\text{out}}$ in any way, we will always lose information about the original sample.</p>
<p>To circumvent this issue, we introduce <strong>“odd-one-out” queries</strong>. This new query type allows the learner to take any point $x\in X$ in the dataset that lies outside of the rectangle $R$, and ask for a violated coordinate (i.e. a feature lying outside one of the specified intervals) and the direction of violation (was the coordinate too large, or too small?). Concretely, imagine a chef is trying to cook a dish for a particularly picky patron. After each failed attempt, the chef asks the patron what went wrong, and the patron responds with some feature they dislike (perhaps the meat was overcooked, or undersalted). It turns out that such scenarios have small lossless compression schemes (and are therefore learnable in bounded memory).</p>
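<p>The odd-one-out response can be modeled as a tiny oracle (a hypothetical helper of ours, with the rectangle given by its corner vectors <code>a</code> and <code>b</code>):</p>

```python
def odd_one_out(x, a, b):
    """For x outside R = prod_i [a_i, b_i], return some violated coordinate
    and the direction of violation; return None if x lies in R."""
    for i, (xi, ai, bi) in enumerate(zip(x, a, b)):
        if xi < ai:
            return (i, "low")    # coordinate i is too small
        if xi > bi:
            return (i, "high")   # coordinate i is too large
    return None
```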
<div class="theorem">
Axis-Aligned Rectangles over $\mathbb{R}^d$ have an $O(d)$-size LCS with respect to label and odd-one-out queries.
</div>
<p>We’ll wrap up our post by sketching the proof. It will be convenient to break our compression scheme into two parts: a scheme for points inside the rectangle, and a scheme for points outside the rectangle.<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup></p>
<p>Let’s start with the former case and restrict our attention to a sample of points $S_{\text{in}}$ that lies entirely inside the rectangle. We claim that all the relevant information in this case is captured by the maximum and minimum values of coordinates in $S_{\text{in}}$. Storing the $2d$ points achieving these values can be viewed as storing a <strong>bounding box</strong> that is guaranteed to lie inside the underlying rectangle classifier.</p>
<p style="text-align: center;"><img src="/assets/2021-11-21-al-memory/inside.png" width="90%" /></p>
<p>Notice that for any point $x \in \mathbb{R}^d$ outside of the bounding box, the version space (that is the set of all rectangles that contain $S_{\text{in}}$) has both a rectangle that contains $x$, and a rectangle that does not contain $x$. This means that label queries on $S_{\text{in}}$ cannot infer any point outside of the bounding box. Since every point inside the box is inferred by the compressed sample, these $2d$ points give a compression set for $S_{\text{in}}$.</p>
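<p>Under these assumptions, the inside-compression is just the coordinatewise extremes (a numpy sketch of ours, not code from the paper):</p>

```python
import numpy as np

def compress_inside(S_in):
    """Keep at most 2d points of S_in: for each coordinate, one achieving the
    min and one achieving the max. Their bounding box infers exactly the
    points that the full sample's labels infer."""
    S = np.asarray(S_in, dtype=float)
    keep = set()
    for i in range(S.shape[1]):
        keep.add(int(S[:, i].argmin()))
        keep.add(int(S[:, i].argmax()))
    return S[sorted(keep)]
```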
<p>Now let’s restrict our attention to a sample $S_{\text{out}}$ that lies entirely outside the rectangle. In this case, we’ll additionally have to compress information given by the odd-one-out oracle as well as labels. Nevertheless, we claim that a simple strategy suffices: store the closest point to each edge of the rectangle.</p>
<p style="text-align: center;"><img src="/assets/2021-11-21-al-memory/outside.png" width="90%" /></p>
<p>In particular, because the odd-one-out oracle gives a violated coordinate and direction of violation, any point that is <em>further out</em> in the direction of violation must also lie outside the rectangle. In any given direction, it is not hard to see that all relevant information is captured by the closest point to the relevant edge, since any further point can be inferred to be too far in that direction.</p>
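<p>The outside-compression mirrors this: per (coordinate, direction) response, keep only the violator closest to that edge. Again a sketch of our own; here <code>responses[j]</code> pairs the $j$-th point with its odd-one-out answer $(i, s)$, where $s = +1$ means the coordinate was too large and $s = -1$ too small:</p>

```python
def compress_outside(S_out, responses):
    """Keep one index per violated edge: the point closest to it. Any point
    farther out in the same direction is inferred to lie outside R."""
    best = {}                                  # (coord, sign) -> (depth, index)
    for j, (i, s) in enumerate(responses):
        depth = s * S_out[j][i]                # smaller = closer to the edge
        if (i, s) not in best or depth < best[(i, s)][0]:
            best[(i, s)] = (depth, j)
    return sorted(j for _, j in best.values())
```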
<h3 id="conclusion">Conclusion</h3>
<p>We’ve now seen that <em>lossless sample compression</em>, the ability to compress finite samples without loss of label inference, gives a simple algorithm for labeling an $n$-point dataset $X$ in $O(\log(n))$ queries while never storing more than $O(1)$ examples at a time. Furthermore, we’ve shown that lossless compression isn’t a hopelessly strong condition—basic real-world questions such as the odd-one-out query often lead to small compression schemes. In <a href="https://arxiv.org/abs/2102.05047">our recent paper</a> we give a few more examples of this phenomenon for richer classes such as decision trees and halfspaces in 2D.</p>
<p>On the other hand, there is still much left to explore! Lossless sample compression gives a sufficient condition for bounded memory active learning, but it is not clear if the condition is necessary. The LCS size is closely related to a necessary condition for active learning called <em>inference dimension</em> (see <a href="https://ucsdml.github.io/jekyll/update/2020/07/27/rel-comp.html">our previous post</a> or <a href="https://arxiv.org/abs/1704.03564">KLMZ’s original paper</a> for a description), and it is an open problem whether these two measures are equivalent. A positive resolution would imply that every actively learnable class is also actively learnable in bounded memory! Finally, it is worth noting that the techniques we discuss in this post are not robust to noise. Building a general framework for the more realistic noise-tolerant regime remains an interesting open question as well.</p>
<h3 id="footnotes">Footnotes</h3>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>Note that this does not immediately imply a compression set for general samples. However, the definition of lossless compression can be weakened to allow for separate compression schemes for positive and negative examples without affecting the resulting implications on bounded memory learnability. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
</li>
</ol>
</div><a href='http://cseweb.ucsd.edu/~nmhopkin/'>Max Hopkins</a>Machine learning practice is dominated by massive supervised algorithms, but gathering sufficient data for these methods can often prove intractable. Active learning is an adaptive technique for annotating large datasets in exponentially fewer queries by finding the most informative examples. Prior works on (worst-case) active learning often require holding the entire dataset in memory, but this can also prove difficult for the desired use-case of big data! In this post, we cover recent work towards characterizing bounded memory active learning, opening the door to applications in settings (e.g., learning on mobile devices) where one can't necessarily hope to store all of your data at once.Understanding Instance-based Interpretability of Variational Auto-Encoders2021-10-15T17:00:00+00:002021-10-15T17:00:00+00:00https://ucsdml.github.io//jekyll/update/2021/10/15/interpretability-vae<h2 id="background-instance-based-interpretability">Background: Instance-based Interpretability</h2>
<p>Modern machine learning algorithms can achieve very high accuracy on many tasks such as image classification. Despite their great success, these algorithms are often <strong>black boxes</strong> as their predictions are mysterious to humans. For example, when we feed an image to a dog-versus-cat classifier, it says: “After a matrix product and max pooling and a non-linearity and a skip connection and another 100 math operations, look, the probability that ‘this image is a cat’ is 99%!” Unfortunately, this makes no sense to a human at all. To understand what is going on, we need information that can be easily interpreted by humans. One way to provide a more interpretable answer is to ask:</p>
<center><i>
Which training samples are most responsible for the prediction of a test sample?
</i></center>
<p><br /></p>
<p>This is called a <strong>counterfactual question</strong>. Instance-based interpretation methods answer this question by designing an interpretability score between every training sample and the test sample. High scores imply importance. Then, we can interpret the prediction by saying: the classifier labels the test image as a cat because these other training samples are cats, and they are most responsible for the prediction of the test image.</p>
<p>The notion of <a href="https://arxiv.org/abs/1703.04730">influence functions</a> is a popular instance-based interpretability method for supervised learning. The intuition is: if removing some $x$ in the training set results in a large difference of the prediction (such as the logits) of $z$, then $x$ is very important for the prediction of $z$. Imagine $z$ is a very special cat that is visually different from all training images except for one sample $x$. Then, $x$ has large influence over $z$ because removing $x$ probably leads to an incorrect prediction of $z$.</p>
<h2 id="interpretations-for-unsupervised-learning">Interpretations for Unsupervised Learning</h2>
<p>For supervised learning, instance-based interpretability methods reveal why a classifier makes a certain prediction. What about unsupervised learning? <a href="https://arxiv.org/abs/2105.14203">Our recent paper</a> investigates this problem for several unsupervised learning methods. The first challenge is, how do we frame the counterfactual question in unsupervised learning?</p>
<p>When the model fits a probability density to the training data, we ask: which training samples are most responsible for <strong>increasing the log-likelihood</strong> of a test sample? In deep generative models such as variational auto-encoders (<a href="https://arxiv.org/abs/1312.6114">VAE</a>), likelihood is not available. VAEs are optimized to maximize the <a href="https://en.wikipedia.org/wiki/Evidence_lower_bound">evidence lower bound</a> (ELBO) of the log-likelihood. Then, we ask: which training samples are most responsible for <strong>increasing the ELBO</strong> of a test sample?</p>
<p>Then, these questions can readily be answered by influence functions with proper loss functions. Formally, let $X=\{x_1,\cdots,x_n\}$ be the training set, and $\mathcal{A}$ be the unsupervised model. That is, $\mathcal{A}(X)$ returns the model fit to $X$. Let $L(X;\mathcal{A}) = \frac1n \sum_{i=1}^n \ell(x_i;\mathcal{A}(X))$ be the loss function, where the loss $\ell$ is negative log-likelihood in density estimators and negative ELBO in VAE. Then, the influence function of a training sample $x_i$ over a test sample $z$ is the difference of the losses at $z$ between models trained with and without $x_i$. Formally, we define the influence function as
\[\mathrm{IF}_{X,\mathcal{A}}(x_i,z) = \ell(z;\mathcal{A}(X\setminus\{x_i\})) - \ell(z;\mathcal{A}(X)).\]
We provide intuition for influence functions in the next section.</p>
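<p>To make the definition concrete, here is a toy numpy sketch (our own, not from the paper) that computes $\mathrm{IF}$ exactly for a one-dimensional Gaussian KDE, where refitting without $x_i$ is free:</p>

```python
import numpy as np

def kde_logdensity(z, X, h):
    """Log-density of a 1D Gaussian KDE with bandwidth h at point z."""
    k = np.exp(-((z - X) ** 2) / (2 * h * h)) / (h * np.sqrt(2 * np.pi))
    return np.log(k.mean())

def influence_kde(z, X, i, h=0.5):
    """IF(x_i, z): loss (negative log-likelihood) at z with x_i removed,
    minus the loss at z under the full training set."""
    return -kde_logdensity(z, np.delete(X, i), h) + kde_logdensity(z, X, h)
```

<p>Points near $z$ come out as proponents (removing them lowers the density, i.e. raises the loss, at $z$), while distant points come out as mild opponents.</p>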
<h2 id="what-should-influence-functions-tell-us">What Should Influence Functions Tell Us?</h2>
<p>What does it mean if $\mathrm{IF}(x_i,z)\gg0$? Unpacking the definition, we have $\ell(z;\mathcal{A}(X\setminus\{x_i\})) \gg \ell(z;\mathcal{A}(X))$, which means removing $x_i$ should result in a large increase of the loss at $z$. In other words, $x_i$ is very important for the model $\mathcal{A}$ to learn $z$. Similarly, if $\mathrm{IF}(x_i,z)\ll0$, then $x_i$ negatively impacts the model in learning $z$; and if $\mathrm{IF}(x_i,z)\approx0$, then $x_i$ hardly impacts it.</p>
<p>For conciseness, we call training samples that have positive / negative influences over a test sample $z$ <strong>proponents</strong> / <strong>opponents</strong> of $z$. In supervised learning, strong proponents and opponents of $z$ are very important to explain the model’s prediction of $z$. Strong proponents help the model correctly predict the label of $z$ because they reduce the loss at $z$, while strong opponents harm it because they increase the loss at $z$. Empirically, strong proponents of $z$ are visually its similar samples from the same class, while strong opponents of $z$ are usually its dissimilar samples from the same class or its similar samples from a different class.</p>
<p>In unsupervised learning, we expect that strong proponents increase the likelihood of $z$ and strong opponents reduce it, so we ask:</p>
<center><i>
Which training samples are strong proponents and opponents of a test sample, respectively?
</i></center>
<p><br /></p>
<p>In particular, when we let $z = x_i$, we obtain a concept called <strong>self influence</strong>, or $\mathrm{IF}(x_i,x_i)$. This concept is very interesting in supervised learning because self influences provide rich information about memorization of training samples. For example, Feldman and Zhang study neural network memorization through the lens of self influences in <a href="https://arxiv.org/abs/2008.03703">this paper</a>. Intuitively, high self influence samples are atypical, ambiguous or mislabeled, while low self influence samples are typical. We want to know what self influences reveal in unsupervised learning, so we ask:</p>
<center><i>
Which training samples have the highest and lowest self influences, respectively?
</i></center>
<p><br /></p>
<p>By looking at these counterfactual questions, we hope to reveal what influence functions can tell us about (1) inductive biases of unsupervised learning models and (2) unrevealed properties of the training set (or distribution) such as outliers.</p>
<h2 id="intuitions-from-classical-unsupervised-learning">Intuitions from Classical Unsupervised Learning</h2>
<p>Let’s first look at these questions in the context of several classical unsupervised learning methods. The goal is to provide intuition on what influence functions should tell us in the unsupervised setting. Consider the following two-dimensional training data $X$ composed of six clusters.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/six_cluster_nocolor.png" width="30%" /></p>
<p>We consider three classical methods: the <a href="http://faculty.washington.edu/yenchic/18W_425/Lec7_knn_basis.pdf">$k$-nearest-neighbour</a> ($k$-NN) density estimator, the <a href="https://en.wikipedia.org/wiki/Kernel_density_estimation">kernel density estimator</a> (KDE), and <a href="https://en.wikipedia.org/wiki/Mixture_model">Gaussian mixture models</a> (GMM). We fit these models on $X$ and the probability densities of these models are shown below.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/density_knn.jpg" width="32%" />
<img src="/assets/2021-10-15-interpretability-vae/density_kde.jpg" width="32%" />
<img src="/assets/2021-10-15-interpretability-vae/density_gmm.jpg" width="32%" /></p>
<h1 id="self-influences">Self influences</h1>
<p>The figure below provides some insights of high and low self influence samples. The color of a point represents its self influence (red means high and blue means low).</p>
<ul>
<li>When using the $k$-NN density estimator, high self influence samples come from a cluster with exactly $k$ points.</li>
<li>When using the KDE density estimator, high self influence samples come from sparse regions, and low self influence samples come from dense regions.</li>
<li>When using the GMM density estimator, high self influence samples are far from cluster centers, and low self influence samples are near cluster centers.</li>
</ul>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/classic_selfinf.png" width="99%" /></p>
<p><br /></p>
<h1 id="proponents-and-opponents">Proponents and Opponents</h1>
<p>The figures below visualize an example of proponents and opponents. The test sample $z$ is marked as the green ✖︎ symbol, and the color of a point represents its influence over the test sample (red means proponents and blue means opponents). In all these models, strong proponents are the nearest neighbours of the test sample.</p>
<ul>
<li>When using the $k$-NN or the KDE density estimator, strong proponents of $z$ are exactly its $k$ nearest neighbours.</li>
<li>KDE seems to be the soft version of $k$-NN: influences over $z$ gradually decrease as distances to $z$ increase.</li>
<li>When using the GMM density estimator, it is surprising to observe that some strong opponents (blue points) of $z$ are from the same cluster! This phenomenon indicates that removing a sample from the same class can possibly increase the likelihood at $z$. To see why this happens, we note that the GMM is parametric and has limited capacity. Therefore, training samples that are far from the cluster centers can largely affect the mean and covariance matrices of the learned Gaussians.</li>
</ul>
<p><br /></p>
<p>Scatter plots of influences of all training samples:</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/classic_testinf.png" width="99%" /></p>
<!-- <center>
<span style="font-size: 25px" > ↓ Zoomed in </span>
</center> -->
<p><br /></p>
<p>And the zoom in view that only shows the cluster which $z$ belongs to:</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/classic_testinf_large.png" width="99%" /></p>
<p><br /></p>
<p>Note: please refer to Section 3 of <a href="https://arxiv.org/abs/2105.14203">our paper</a> for the closed-form influence functions.</p>
<p><br /></p>
<h2 id="variational-auto-encoders-vae">Variational Auto-Encoders (<a href="https://arxiv.org/abs/1312.6114">VAE</a>)</h2>
<p>Variational auto-encoders are a class of generative models composed of two networks: the encoder, which maps samples to latent vectors, and the decoder, which maps latent vectors to samples. These models are trained to maximize the <a href="https://en.wikipedia.org/wiki/Evidence_lower_bound">evidence lower bound</a> (ELBO), a lower bound of log-likelihood.</p>
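<p>For reference, with encoder $q_\phi(z \mid x)$ and decoder $p_\theta(x \mid z)$, the ELBO being maximized is the standard quantity
\[
\mathrm{ELBO}(x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z)\right) \leq \log p_\theta(x),
\]
so the loss $\ell$ used in the influence function above is the negative of this quantity.</p>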
<p>There are two challenges when we investigate influence functions in VAE.</p>
<ul>
<li>The influence function involves computing the loss at a test sample. However, the ELBO objective in VAE has an expectation term over the encoder, so it cannot be precisely computed.
<ul>
<li>Solution: we compute the empirical average of the influence function. We provide a probabilistic error bound on this estimate: if the empirical average is over $\Theta\left(\frac{1}{\epsilon^2\delta}\right)$ samples, then with probability at least $1-\delta$, the error between the empirical and true influence functions is no more than $\epsilon$.</li>
</ul>
<p><br /></p>
</li>
<li>The influence function is hard to compute, as it requires inverting a large Hessian matrix. The number of rows in this matrix equals the number of parameters in the VAE model, which can be as large as a million. Consequently, inverting this matrix (or even computing Hessian vector products) can be computationally infeasible.
<ul>
<li>Solution: we propose a computationally efficient algorithm called VAE-TracIn. It is based on the fast <a href="https://arxiv.org/abs/2002.08484">TracIn</a> algorithm, an approximation to influence functions. TracIn is efficient because (1) it only involves computing the first-order derivative of the loss, and (2) it can be accelerated with only a few checkpoints.</li>
</ul>
</li>
</ul>
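<p>The resulting score reduces to learning-rate-weighted gradient dot products over a few checkpoints. A rough numpy sketch of the idea (our own simplification, not the paper's implementation; <code>grads_train[t]</code> and <code>grads_test[t]</code> stand for flattened loss gradients at checkpoint $t$):</p>

```python
import numpy as np

def tracin_score(grads_train, grads_test, lrs):
    """TracIn-style influence of x_i on z: sum over checkpoints t of
    eta_t * <grad loss(x_i; w_t), grad loss(z; w_t)>. First-order only --
    no Hessian is ever formed or inverted."""
    return sum(lr * float(np.dot(gi, gz))
               for lr, gi, gz in zip(lrs, grads_train, grads_test))
```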
<h1 id="a-sanity-check">A sanity check</h1>
<p>Does VAE-TracIn find the most influential training samples? In a good instance-based
interpretation, training samples should have large influences over themselves. Therefore, we design the following sanity check (which is analogous to the identical subclass test by Hanawa et al. in <a href="https://openreview.net/pdf/ede4daa61cd87856ebce2c047d94f9fdc6149edf.pdf">this reference</a>):</p>
<center><i>
Are training samples the strongest proponents over themselves?
</i></center>
<p><br /></p>
<p>The short answer is: yes. We visualize some training samples and their strongest proponents in the figures below. A sample is marked in a green box if it is exactly its strongest proponent, and in a red box otherwise. Quantitatively, almost all ($>99\%$) training samples are the strongest proponents of themselves, with only very few exceptions. And as shown, even if a sample is not its strongest proponent, it still ranks very high in the order of influence scores.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/sanity_mnist.png" width="30%" />
<img src="/assets/2021-10-15-interpretability-vae/sanity_cifarsub.png" width="30%" /></p>
<p><br /></p>
<h1 id="self-influences-for-vaes">Self influences for VAEs</h1>
<p>We visualize <strong>high</strong> self influence samples below. We find these samples are either hard to recognize or visually high-contrast.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/high_selfinf.png" width="80%" /></p>
<p><br /></p>
<p>We then visualize <strong>low</strong> self influence samples below. We find these samples share similar shapes or background.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/low_selfinf.png" width="80%" /></p>
<p><br /></p>
<p>These findings are consistent with the memorization analysis in the supervised setting by Feldman and Zhang in <a href="https://arxiv.org/abs/2008.03703">this reference</a>. Intuitively, high self influence samples are very different from most samples, so they must be memorized by the model. Low self influence samples, on the other hand, are very similar to each other, so the model does not need to memorize all of them. Quantitatively, we also find self influences correlate to the loss of training samples: generally, the larger loss, the larger self influence.</p>
<!-- The relationship is demonstrated in the scatter plot below: generally, the larger loss, the larger self influence.
{:refdef: style="text-align: center;"}
<img src="/assets/2021-10-15-interpretability-vae/selfinf_vs_loss.png" width="35%">
{:refdef}
<br /> -->
<p>The intuition on self influences leads to an application in unsupervised data cleaning. Because high self influence samples are visually complicated and different, they are likely to be outside the data manifold. Therefore, we can use self influences to detect unlikely (noisy, contaminated, or even incorrectly collected) samples. For example, they could be
unrecognizable handwritten digits or objects in MNIST or CIFAR. Similar approaches in supervised learning use self influences to detect mislabeled data or memorized samples.</p>
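A minimal sketch of the self-influence computation in the TracIn style, on a toy one-parameter model rather than a VAE (the checkpoints, learning rates, and data below are invented for illustration, not taken from the paper):

```python
def grad(mu, x):
    # gradient of the per-sample loss (x - mu)^2 w.r.t. the model parameter mu
    return -2.0 * (x - mu)

def self_influence(checkpoints, lrs, x):
    """TracIn-style self influence: sum over checkpoints of lr * ||grad||^2."""
    return sum(lr * grad(mu, x) ** 2 for mu, lr in zip(checkpoints, lrs))

# invented training trajectory: parameter checkpoints and their learning rates
ckpts, lrs = [0.0, 0.3, 0.5], [0.1, 0.1, 0.1]
typical, outlier = 0.6, 5.0
# the atypical sample has much larger self influence, flagging it for data cleaning
assert self_influence(ckpts, lrs, outlier) > self_influence(ckpts, lrs, typical)
```

In a real VAE the gradient would be taken over the network parameters, but the ranking logic is the same: samples whose loss gradients stay large across the trajectory score high.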
<h1 id="proponents-and-opponents-in-vaes">Proponents and Opponents in VAEs</h1>
<p>We visualize strong proponents and opponents of several test samples below.</p>
<p style="text-align: center;"><img src="/assets/2021-10-15-interpretability-vae/testinf.png" width="95%" /></p>
<p><br /></p>
<p>In MNIST, many strong proponents and opponents of test samples are similar samples from the same class. In particular, strong proponents look very similar to the test samples, while strong opponents are visually slightly different. For example, the opponents of the test “two” have very different thickness and styles. Quantitatively, $\sim 80\%$ of the strongest proponents and $\sim 40\%$ of the strongest opponents have the same label as the test samples. In addition, both have small latent space distances to the test samples. This behavior is very similar to that of a Gaussian mixture model (GMM).</p>
<p>In CIFAR, we find strong proponents seem to match the color of the images – including the background and the object – and they tend to have the same but brighter colors. Strong opponents, on the other hand, tend to have very different colors from the test samples. Quantitatively, strong proponents have large norms in the latent space, indicating they are very likely to be outliers, high-contrast samples, or very bright samples. This observation is also validated in the visualizations. One can further connect this observation to influence functions in supervised learning. Hanawa et al. find extremely large norm samples are selected as relevant instances by influence functions in <a href="https://openreview.net/pdf?id=9uvhpyQwzM_">this reference</a>, and Barshan et al. find large norm samples can impact a large region in the data space when using logistic regression in <a href="https://arxiv.org/abs/2003.11630">this reference</a>.</p>
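In a TracIn-style analysis, the influence of a training sample on a test sample is a learning-rate-weighted sum of gradient inner products across checkpoints: positive values mark proponents and negative values mark opponents. Here is a minimal sketch on a toy one-parameter model (the checkpoints, learning rates, and data are invented for illustration):

```python
def grad(mu, x):
    # gradient of the per-sample loss (x - mu)^2 w.r.t. the model parameter mu
    return -2.0 * (x - mu)

def influence(checkpoints, lrs, x_train, x_test):
    """TracIn-style influence of x_train on the loss at x_test:
    positive values mark proponents, negative values mark opponents."""
    return sum(lr * grad(mu, x_train) * grad(mu, x_test)
               for mu, lr in zip(checkpoints, lrs))

ckpts, lrs = [0.0, 0.3, 0.5], [0.1, 0.1, 0.1]
x_test = 1.0
assert influence(ckpts, lrs, 0.9, x_test) > 0    # similar sample: proponent
assert influence(ckpts, lrs, -1.0, x_test) < 0   # dissimilar sample: opponent
```

The toy model makes the qualitative finding above easy to see: samples whose gradients point the same way as the test sample's gradient are proponents, and samples pulling the parameters the opposite way are opponents.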
<h2 id="open-questions">Open Questions</h2>
<p>There are many open questions based on our paper. Here is a list of some important future directions.</p>
<ul>
<li>How can we design efficient instance-based interpretation methods for modern, large unsupervised learning models trained on millions of samples?</li>
<li>How can we use instance-based interpretations to detect bias and fairness issues in models and data?</li>
<li>What are the other applications of instance-based interpretation methods?</li>
</ul>
<h2 id="more-details">More Details</h2>
<p>See <a href="https://arxiv.org/abs/2105.14203">the full paper on arxiv</a>.</p>
<p><a href='https://cseweb.ucsd.edu/~z4kong'>Zhifeng Kong</a> and <a href='http://cseweb.ucsd.edu/~kamalika'>Kamalika Chaudhuri</a></p>
<p><em>Summary:</em> Instance-based interpretation methods, such as influence functions, have been widely studied for supervised learning methods as they help explain how black box neural networks predict. However, these interpretations remain ill-understood in the context of unsupervised learning. In this paper, we introduce VAE-TracIn, a computationally efficient and theoretically sound solution to investigate influence functions for variational auto-encoders (VAE). Our experiments reveal patterns about the impact of training data in VAE.</p>
<h1>Connecting Interpretability and Robustness in Decision Trees through Separation</h1>
<p><em>2021-09-24</em> · <a href="https://ucsdml.github.io//jekyll/update/2021/09/24/interpretable-robust-trees">https://ucsdml.github.io//jekyll/update/2021/09/24/interpretable-robust-trees</a></p>
<p><strong>TL;DR</strong> We construct a tree-based model that is <em>guaranteed</em>
to be adversarially robust, interpretable, and accurate.</p>
<p>Imagine a world where computers are fully integrated into
our everyday lives. Making decisions independently, without
human intervention. No need to worry about overly exhausted
doctors making life-changing decisions or driving your car
after a long day at the office. Sounds great, right? Well,
what if those computers weren’t reliable? What if a
computer decided you need to go through surgery without
telling you why? What if a car confused a child with a
green light? It doesn’t sound so great after all.</p>
<p>Before we fully embrace machine learning, it needs to be reliable.
The cornerstones for reliable machine learning are (i) interpretability,
where the model’s decisions are transparent, and (ii) robustness, where small changes
to the input do not change the model’s prediction.
Unfortunately, these properties are generally studied in isolation or only empirically.
Here, we explore interpretability and robustness <em>simultaneously</em>,
and examine them <em>both theoretically and empirically</em>.</p>
<!--In this post, our objective is to build a decision tree with guarantees
on its accuracy, robustness, and interpretability.-->
<p>We start this post by explaining what we mean by interpretability and robustness.
Next, to derive guarantees, we need some assumptions on the data.
We start with the known notion of <a href="http://proceedings.mlr.press/v80/wang18c.html">$r$-separated data</a>.
We show that although there exists a tree that is accurate and robust,
such a tree can be exponentially large, which makes it uninterpretable.
To improve the guarantees, we make a stronger assumption on the data
and focus on linearly separable data.
We design an algorithm called BBM-RS and prove that it is accurate, robust, and interpretable on
linearly separable data.
Lastly, real datasets may not be linearly separable, so to understand how BBM-RS performs in practice,
we conduct an empirical study on $13$ datasets.
We find that BBM-RS brings better robustness and interpretability while performing competitively
on test accuracy.</p>
<h2 id="what-do-we-mean-by-interpretability-and-robustness">What do we mean by interpretability and robustness?</h2>
<h3 id="interpretability">Interpretability</h3>
<p>A model is <strong>interpretable</strong> if the model is simple and self-explanatory.
There are several forms of
<a href="https://christophm.github.io/interpretable-ml-book/simple.html">self-explanatory models</a>,
e.g., <a href="https://www-cs-faculty.stanford.edu/people/jure/pubs/interpretable-kdd16.pdf">decision sets</a>,
<a href="https://en.wikipedia.org/wiki/Logistic_regression">logistic regression</a>, and
<a href="https://christophm.github.io/interpretable-ml-book/rules.html">decision rules</a>.
One of the most fundamental interpretable models, and our focus here, is the
<strong>small</strong> decision tree.
We use the size of a tree to determine whether it is interpretable or not.</p>
<h3 id="robustness">Robustness</h3>
<p>We also want our model to be robust to adversarial perturbations.
This means that if example $x$ is changed, by a bit, to $x’$, the model’s
answer remains the same.
By “a bit”, we mean that $x’=x+\delta$ where $\|\delta\|_\infty\leq r$ is
small. A model $h:\mathbf{X} \rightarrow \{-1, +1\}$ is <em>robust</em> at $x$ with radius
$r$ if for all such $x’$ we have that $h(x)=h(x’)$. The notion of <em>astuteness</em>
<a href="http://proceedings.mlr.press/v80/wang18c.html">was previously introduced</a>
to jointly measure the robustness and the accuracy of a model.
The astuteness of a model $h$ at radius $r > 0$ under a distribution $\mu$
is \[\Pr_{(x,y)\sim\mu}\left[h(x’)=y \ \text{ for all } x’ \text{ with } \|x-x’\|_\infty\leq r\right].\]</p>
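As a toy illustration of these definitions, for a one-dimensional threshold classifier the robustness check reduces to the distance from the threshold, so empirical astuteness over a sample can be computed directly (the classifier and data below are invented):

```python
def astuteness(theta, samples, r):
    """Empirical astuteness of the 1-D classifier h(x) = sign(x - theta):
    the fraction of (x, y) pairs classified correctly AND robustly at radius r."""
    def h(x):
        return 1 if x - theta >= 0 else -1
    def astute(x, y):
        correct = h(x) == y
        robust = abs(x - theta) > r  # every x' with |x - x'| <= r gets the same sign
        return correct and robust
    return sum(astute(x, y) for x, y in samples) / len(samples)

samples = [(-2.0, -1), (-0.4, -1), (0.3, 1), (1.5, 1)]
assert astuteness(0.0, samples, r=0.5) == 0.5  # the two near-threshold points fail
```

Note how shrinking the radius $r$ to zero recovers plain accuracy, while a larger $r$ penalizes correct predictions that sit too close to the decision boundary.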
<h2 id="guarantees-under-different-data-assumptions">Guarantees under different data assumptions</h2>
<p>Without any assumptions on the data, we cannot guarantee
accuracy, interpretability, and robustness to hold simultaneously.
For example, if two nearby examples have different true labels,
no model can be astute (i.e., both accurate and robust).
In this section, we explore which data properties are sufficient for astuteness and interpretability.</p>
<h3 id="r-separation">$r$-Separation</h3>
<p><a href="http://proceedings.mlr.press/v80/wang18c.html">A prior work</a> suggested
focusing on datasets that satisfy $r$-separation.
A binary labeled data distribution is <em>$r$-separated</em> if every two differently labeled examples, $(x^1,+1)$,$(x^2,-1)$, are far apart,
$\|x^1-x^2\|_\infty\geq 2r.$
<a href="https://arxiv.org/abs/2003.02460">Yang et al.</a> showed that
$r$-separation is sufficient for robust learning.
Therefore, we examine whether it is also sufficient for accuracy and
interpretability.
We have two main findings.
First, we found that there is an accurate decision tree with size
independent of the number of examples.
Second, we discovered that the size of the accurate tree can be exponential
in the number of features.
Combining these two findings, it appears we need to find a stronger assumption on the data to
be able to have guarantees on both accuracy and interpretability.</p>
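Checking whether a finite dataset is $r$-separated is straightforward: compute the minimum $\ell_\infty$ distance across differently labeled pairs and compare it to $2r$ (a brute-force sketch on invented data):

```python
def is_r_separated(samples, r):
    """Check that every pair of differently labeled points is at
    l_inf distance at least 2r."""
    pos = [x for x, y in samples if y == +1]
    neg = [x for x, y in samples if y == -1]

    def linf(a, b):
        return max(abs(ai - bi) for ai, bi in zip(a, b))

    return all(linf(p, n) >= 2 * r for p in pos for n in neg)

data = [((0.0, 0.0), -1), ((1.0, 1.0), +1), ((0.1, 0.2), -1)]
assert is_r_separated(data, r=0.4)       # closest cross pair is at l_inf distance 0.9
assert not is_r_separated(data, r=0.5)
```

The quadratic pairwise scan is fine for a sketch; on large datasets one would use nearest-neighbor data structures instead.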
<h3 id="linear-separation">Linear separation</h3>
<p>Next, we investigate a stronger assumption — linear separation with a
$\gamma$-margin.
Intuitively, it means that a hyperplane separates the two labels in the data,
and the margin (distance of the closest point to the hyperplane) is at
least $\gamma$ (larger $\gamma$ means larger margin for the classifier).
More formally, there exists a vector $w$ with $\|w\|_1=1$ such that
for each training example and its label $(x, y)$, we have
$y\,(w\cdot x)\geq \gamma$.
Linear separation is a popular assumption in the research of
machine learning models, e.g., for
<a href="https://en.wikipedia.org/wiki/Support-vector_machine">support vector machines</a>,
<a href="https://arxiv.org/abs/1705.08292">neural networks</a>,
and <a href="https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.129.6343&rep=rep1&type=pdf">decision trees</a>.</p>
<p>Using a generalization of
<a href="https://www.cs.huji.ac.il/~shais/papers/ShalevSi08.pdf">previous work</a>,
we know that under the linear separation assumption, there has to be a feature
that gives nontrivial information.
To formalize it, we use the notion of
<a href="https://en.wikipedia.org/wiki/Decision_stump">decision stumps</a> and
<a href="https://en.wikipedia.org/wiki/Boosting_(machine_learning)">weak learners</a>.
A decision stump is a (simple) hypothesis of the form $sign(x_i-\theta)$ defined
by a feature $i$ and a threshold $\theta$.
A hypothesis class is a $\gamma$-weak learner if it always contains a
hypothesis whose accuracy is at least $1/2+\gamma$, i.e., (slightly) better
than random guessing.
<p>Now, we look at the hypothesis class of all possible decision stumps, and we want
to show that this class is a weak learner.
For each dataset $S=((x^1,y^1),\ldots,(x^m,y^m))$, we denote
the best decision stump for this dataset by $h_S(x)=sign(x_i-\theta)$, where $i$
is a feature and $\theta$ is a threshold that minimize the number of mistakes
$\sum_{j=1}^m \mathbb{1}\left[sign(x^j_i-\theta)\neq y^j\right].$
We can show that $h_S$ has accuracy better than $0.5$, i.e., better than a
random guess:</p>
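A brute-force search over feature-threshold pairs makes the definition of $h_S$ concrete. In the sketch below (on invented toy data), a sign flip is also allowed so the stump can predict $sign(\theta - x_i)$ as well, which is a common convenience not spelled out above:

```python
def best_stump(samples):
    """Brute-force search over feature-threshold pairs for the most accurate
    decision stump; a sign flip lets the stump predict sign(theta - x_i) too."""
    m, d = len(samples), len(samples[0][0])
    best_acc, best_params = 0.0, None
    for i in range(d):
        for theta in sorted({x[i] for x, _ in samples}):
            for flip in (1, -1):
                acc = sum(flip * (1 if x[i] - theta >= 0 else -1) == y
                          for x, y in samples) / m
                if acc > best_acc:
                    best_acc, best_params = acc, (i, theta, flip)
    return best_acc, best_params

data = [((0.0, 3.0), -1), ((0.2, 9.0), -1), ((0.8, 1.0), +1), ((0.9, 7.0), +1)]
acc, (i, theta, flip) = best_stump(data)
assert acc == 1.0 and i == 0  # feature 0 with threshold 0.8 separates the toy data
```

The search costs $O(d \cdot m^2)$ as written; sorting each feature once brings it down to $O(d \cdot m \log m)$ in practice.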
<div class="theorem">
Fix $\alpha>0$.
For any distribution $\mu$ over $[-1,+1]^d\times\{-1,+1\}$ that satisfies
linear separability with a $\gamma$-margin, and for any $\delta\in(0,1)$ there
is $m=O\left(\frac{d+\log\frac1\delta}{\gamma^2}\right)$, such that with
probability at least $1-\delta$ over the sample $S$ of size $m$, it holds that
$$\Pr_{(x,y)\sim\mu}(h_S(x)=y)\geq \frac12+\frac{\gamma}{4}-\alpha.$$
</div>
<p>This result proves that there exists a classifier $h_S$ in the hypothesis class of
all possible decision stumps that produces a non-trivial
solution under the linear separability assumption.
Using this theorem along with the result from
<a href="https://www.sciencedirect.com/science/article/pii/S0022000097915439">Kearns and Mansour</a>,
we can show that
<a href="https://onlinelibrary.wiley.com/doi/full/10.1002/widm.8?casa_token=O2ehHd8cYlwAAAAA%3AplOtiUnZ41vnEXcvBTZiQxwPJfl1DTFB4ROZX8fX7VP0uXhyxJoqXmRKAIdUyaXRHe7EP1Y860w38A">CART</a>-type
algorithms can deliver a <strong>small</strong> tree with high accuracy.
As a side benefit, this is the <em>first</em> time that a distributional
assumption that does not include feature independence is used.
Many papers on theoretical guarantees of decision trees assumed either uniformity or feature independence
(papers <a href="https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.129.6343&rep=rep1&type=pdf">1</a>,
<a href="https://arxiv.org/abs/1911.07375">2</a>, and
<a href="http://proceedings.mlr.press/v125/brutzkus20a/brutzkus20a.pdf">3</a>).</p>
<p>Are we done? Is this model also robust?</p>
<h2 id="new-algorithm-bbm-rs">New algorithm: BBM-RS</h2>
<p>Designing robust decision trees is inherently a difficult task.
One reason is that, generally, the models defined by the right and left subtrees
can be completely different.
The feature $i$ in the root determines if the model
uses the right or left subtree.
Thus, a small change in the $i$-th feature completely changes the model.
To overcome this difficulty, we focus on a specific class of decision trees.</p>
<h3 id="risk-score">Risk score</h3>
<p>We design our algorithm to learn a specific kind of decision tree —
<a href="https://jmlr.org/papers/volume20/18-615/18-615.pdf">risk score</a>.
A risk score is composed of several conditions (e.g., $age \geq 75$), and each
is matched with an integer weight.
A score $s(x)$ of example $x$ is the weighted sum of all the satisfied
conditions.
The label is then $sign(s(x))$.</p>
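The scoring rule can be sketched in a few lines (condition names and weights are illustrative, loosely modeled on the bank example in the table below):

```python
def risk_score_predict(weights, bias, features):
    """weights: dict mapping condition name -> integer weight;
    features: dict mapping condition name -> bool (is the condition satisfied?).
    Returns the total score s(x) and the predicted label sign(s(x))."""
    score = bias + sum(w for cond, w in weights.items() if features.get(cond, False))
    return score, (1 if score > 0 else -1)

# hypothetical conditions in the style of the bank example below
weights = {"age >= 75": 2, "called before": 4, "previous call successful": 2}
score, label = risk_score_predict(weights, bias=-5,
                                  features={"age >= 75": True, "called before": True})
assert (score, label) == (1, 1)  # -5 + 2 + 4 = 1 > 0, so predict "1"
```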
<div style="width: 100%; overflow-x: auto; margin-bottom: 35pt">
<table style="width: 90%; font-size: 80%; margin: auto;" class="concise-table">
<caption>
An example of the risk score model on the <a href="https://core.ac.uk/download/pdf/55631291.pdf">bank dataset</a>.
Each satisfied condition is multiplied by its weight, and the results are summed; the bias term is always satisfied.
If the total score is $>0$, the risk model predicts "1" (i.e., the client will open a bank account after a marketing call).
All features are binary (either $0$ or $1$).
As a concrete example, a person who is older than 75 and was called before, but whose previous
call was not successful, gets a total score of $-5+2+4=1$, so the prediction is "1".
</caption>
<tr>
<th colspan="1">features</th>
<th colspan="2" style="text-align: center;">weights</th>
</tr>
<tr>
<th>Bias term</th> <th style="text-align: center;">-5</th> <th> + ... </th>
</tr>
<tr>
<th>Age $\geq 75$</th> <th style="text-align: center;">2</th> <th> + ... </th>
</tr>
<tr>
<th>Called before</th> <th style="text-align: center;">4</th> <th> + ... </th>
</tr>
<tr>
<th>Previous call was successful</th> <th style="text-align: center;">2</th> <th> + ... </th>
</tr>
<tr>
<th></th> <th>total scores=</th> <th> </th>
</tr>
</table>
</div>
<p>A risk score can be viewed as a decision tree with the same feature-threshold pair at
each level (see example below).
A risk score has a simpler structure than a standard decision tree,
and it generally has fewer <em>unique</em> nodes.
Hence, they are considered
<a href="https://jmlr.org/papers/volume20/18-615/18-615.pdf">more interpretable than decision trees</a>.
The following table shows an example of a risk score.</p>
<div style="width: 100%; overflow-x: auto; margin-bottom: 20pt">
<table style="width: 100%; font-size: 80%; margin-bottom: 2pt;" class="concise-table">
<caption>
Here is an example of how to convert a risk score into a decision tree.
The table on the left is an example of a risk score that may be used by a doctor to determine
whether a patient caught a cold or not.
It has three conditions and the figure on the right is the corresponding decision tree.
For each node in the tree, the branch towards the right represents the path to take if the condition is true.
The leaves represent the final risk score of the corresponding path.
As a concrete example, if a patient has a fever and coughs but does not sneeze, we follow the green
path in the decision tree and arrive at a score of $-3+3+2=2$.
</caption>
</table>
<span style="width: 49%; overflow-x: auto; display: inline-block; margin: 10pt; float: left;">
<table style="width: 100%; font-size: 80%; margin: auto;" class="concise-table">
<tr>
<th colspan="1" style="text-align: left;">features</th>
<th colspan="1" style="text-align: center;">weights</th>
<th colspan="1"></th>
</tr>
<tr><th>Bias term</th> <th style="text-align: center;">-3</th> <th> + ... </th></tr>
<tr><th>Fever</th> <th style="text-align: center;">3</th> <th> + ... </th></tr>
<tr><th>Sneeze</th> <th style="text-align: center;">1</th> <th> + ... </th></tr>
<tr><th>Cough</th> <th style="text-align: center;">2</th> <th> + ... </th></tr>
<tr><th></th> <th>total scores=</th> <th></th></tr>
</table>
</span>
<span style="width: 47%; overflow-x: auto; display: inline-block; margin: auto; float: right;">
<img src="/assets/2021-09-24-interpretable-robust-trees/risk_score_tree.png" style="width: 100%" />
</span>
</div>
<h3 id="bbm-rs">BBM-RS</h3>
<p>We design a new algorithm for learning risk scores by utilizing the known
boosting method
<a href="https://link.springer.com/content/pdf/10.1023/A:1010852229904.pdf">boost-by-majority</a>
(BBM).
The different conditions are added to the risk score one by one, using
the weak learner.
BBM has the benefit of ensuring the weights in the risk score
are small integers.
This leads to an interpretable model of size only
$O(\gamma^{-2}\log(1/\epsilon))$ for a model with accuracy $1-\epsilon$.</p>
<p>Now we want to make sure that the risk model is also robust.
The idea is to add noise: we take each point in the sample and move it a little
closer to the decision boundary; see the figure below.</p>
<p style="text-align: center;"><img src="/assets/2021-09-24-interpretable-robust-trees/BBM_RS_add_noise.png" /></p>
<p>The idea is that if the model is correct for the noisy point, then it
should be correct for the point without the noise.
To formally prove it, we show that choosing the risk-score conditions in a specific
way ensures that they are monotone models.
In such models, adding noise in the way we described is
sufficient for robustness.</p>
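The shifting trick can be sketched as follows: for a monotone classifier, moving each labeled point by $r$ in its worst-case direction before training means that correctness on the shifted point implies correctness on the whole $\ell_\infty$ ball around the original point. The toy classifier below is an invented stand-in for a learned risk score:

```python
def shift_toward_boundary(x, y, r):
    """Move a labeled point by r per coordinate in the worst-case direction for a
    monotone classifier: positive points move down, negative points move up."""
    return tuple(xi - y * r for xi in x)

# invented monotone stand-in for a risk score: +1 iff the coordinate sum is nonnegative
def h(x):
    return 1 if sum(x) >= 0 else -1

x, y, r = (0.6, 0.5), 1, 0.2
assert h(shift_toward_boundary(x, y, r)) == y
# for this monotone h, correctness on the shifted point implies correctness on
# the whole l_inf ball of radius r around x; check all corners of the ball
corners = [(x[0] + s0 * r, x[1] + s1 * r) for s0 in (-1, 1) for s1 in (-1, 1)]
assert all(h(c) == y for c in corners)
```

Checking the corners suffices here because the toy classifier is linear; the monotonicity argument in the post covers the general case.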
<p>Before we examine this algorithm on real datasets, let’s check its running time.
We focus on the case where the margin and desired accuracy are constants.
In this case, the number of steps BBM-RS will take is also constant.
In each step, we run the weak learner and find the best $(i,\theta)$.
So the overall time is linear (up to logarithmic factors) in the input size and the time to run the
weak learner.</p>
<p>To summarize, we designed a new efficient algorithm, BBM-RS, that is robust, interpretable, and
has high accuracy. The following theorem shows this. Please refer to
<a href="https://arxiv.org/abs/2102.07048">our paper</a> for the pseudocode of BBM-RS
and more details for the theorem.</p>
<div class="theorem">
Suppose data is $\gamma$-linearly separable and fix $\epsilon,\delta\in(0,1)$.
Then, with probability $1-\delta$ the output of BBM-RS, after receiving
$(d+\log(1/\delta))\log(1/\epsilon)\gamma^{-O(1)}$ samples, has astuteness
$1-\epsilon$ at radius $\gamma/2$ and has $O(\gamma^{-2}\log(1/\epsilon))$
feature-threshold pairs.
</div>
<h3 id="performance-on-real-data">Performance on real data</h3>
<p>For BBM-RS, our theorem is restricted to linearly separable data.
However, real datasets may not be perfectly linearly separable.
This raises a natural question: is linear separability a reasonable
assumption in practice?</p>
<p>To answer this question, we consider $13$ real datasets (here we present the
results for four datasets; for more datasets, please refer to <a href="https://arxiv.org/abs/2102.07048">our
paper</a>).
We measure how linearly separable each of these datasets is.
We define the <strong>linear separateness</strong> as one minus the minimal fraction
of points that needed to be removed for the data to be linearly separable.
Since finding the optimal linear separateness on arbitrary data
<a href="https://www.sciencedirect.com/science/article/pii/S0022000003000382">is NP-hard</a>,
we approximate linear separateness with the training accuracy of the best linear classifier we can
find (since removing the incorrect examples for a linear classifier would make the dataset linearly
separable).
We train linear SVMs with different regularization parameters and record the best training accuracy.
After removing the points misclassified by the SVM, a fraction of examples equal
to the training accuracy remains, and this subset is linearly separable.
The higher this accuracy, the more linearly separable the data.
The following table shows the results; it reveals that most datasets
are very or moderately close to being linearly separable.
This indicates that the linear assumption in our theorem may not be too
restrictive in practice.</p>
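The measurement procedure can be sketched with any linear learner; below, a simple perceptron stands in for the linear SVMs used in the experiments, and the toy dataset is invented:

```python
def perceptron_train_accuracy(samples, epochs=100):
    """Approximate 'linear separateness' with the training accuracy of a linear
    classifier (a perceptron here, standing in for a tuned linear SVM)."""
    d = len(samples[0][0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        for x, y in samples:
            # classic perceptron update on a mistake
            if y * (sum(wi * xi for wi, xi in zip(w, x)) + b) <= 0:
                w = [wi + y * xi for wi, xi in zip(w, x)]
                b += y
    correct = sum(y * (sum(wi * xi for wi, xi in zip(w, x)) + b) > 0
                  for x, y in samples)
    return correct / len(samples)

data = [((0.0, 0.0), -1), ((0.1, 0.1), -1), ((1.0, 1.0), +1), ((0.9, 1.1), +1)]
assert perceptron_train_accuracy(data) == 1.0  # this toy set is linearly separable
```

On separable data the perceptron converges and reports separateness $1.0$; on non-separable data it reports the fraction it manages to fit, a lower bound on the true separateness.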
<div style="width: 100%; overflow-x: auto;">
<table style="width: 50%; font-size: 80%; margin: auto" class="concise-table">
<tr>
<th colspan="1"></th> <th colspan="1" style="text-align: center;">linear separateness</th>
</tr>
<tr>
<th colspan="1">adult</th> <th colspan="1" style="text-align: center;">0.84</th>
</tr>
<tr>
<th colspan="1">breastcancer</th> <th colspan="1" style="text-align: center;">0.97</th>
</tr>
<tr>
<th colspan="1">diabetes</th> <th colspan="1" style="text-align: center;">0.77</th>
</tr>
<tr>
<th colspan="1">heart</th> <th colspan="1" style="text-align: center;">0.89</th>
</tr>
</table>
</div>
<p>Even though these datasets are not perfectly linearly separable, BBM-RS can
still be applied (though the theorem may no longer hold).
We are interested in how BBM-RS performs against other methods on these
datasets.
We compare BBM-RS to three baselines,
<a href="https://arxiv.org/abs/1610.00168">LCPA</a>,
<a href="https://books.google.co.il/books?hl=en&lr=&id=MGlQDwAAQBAJ&oi=fnd&pg=PP1&ots=gBmdjTJVdK&sig=\_jUBiPW4cTS7JYUKpzKcJLYipl4&redir_esc=y#v=onepage&q&f=false">decision tree (DT)</a>, and
<a href="http://proceedings.mlr.press/v97/chen19m/chen19m.pdf">robust decision tree (RobDT)</a>.
We measure a model’s robustness by evaluating its
<a href="https://arxiv.org/abs/2003.02460"><strong>Empirical robustness (ER)</strong></a>, which is the
average $\ell_\infty$
distance to the closest adversarial example on correctly predicted test examples.
The larger ER is, the more robust the classifier is.
We measure a model’s interpretability by evaluating its
<strong>interpretation complexity (IC)</strong>.
We measure IC with the number of unique feature-threshold pairs in the
model (this corresponds to the number of conditions in the risk score).
The smaller IC is, the more interpretable the classifier is.
The following tables show the experimental results.</p>
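For intuition about the ER metric just defined, consider a one-dimensional threshold classifier, where the distance to the closest adversarial example is simply the distance to the threshold (a toy sketch with invented data; computing ER for real trees requires a search over the tree's decision regions):

```python
def empirical_robustness(theta, samples):
    """ER for the 1-D classifier h(x) = sign(x - theta): the average distance to
    the closest adversarial example over correctly predicted samples."""
    def h(x):
        return 1 if x - theta >= 0 else -1
    # for a stump, flipping the prediction requires crossing the threshold,
    # so the closest adversarial example sits at distance |x - theta|
    dists = [abs(x - theta) for x, y in samples if h(x) == y]
    return sum(dists) / len(dists)

samples = [(-2.0, -1), (-0.5, -1), (0.5, 1), (3.0, 1)]
assert empirical_robustness(0.0, samples) == 1.5  # mean of 2.0, 0.5, 0.5, 3.0
```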
<div style="width: 100%; overflow-x: auto;">
<table style="width: 60%; font-size: 80%; margin: auto" class="concise-table">
<tr>
<th colspan="1"></th>
<th colspan="4" style="text-align: center;">test accuracy (higher=better)</th>
</tr>
<tr>
<th colspan="1"></th>
<th colspan="1" style="text-align: center;">DT</th>
<th colspan="1" style="text-align: center;">RobDT</th>
<th colspan="1" style="text-align: center;">LCPA</th>
<th colspan="1" style="text-align: center;">BBM-RS</th>
</tr>
<tr>
<th colspan="1">adult</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.83</b></th>
<th colspan="1" style="text-align: center; color: green;"><b>0.83</b></th>
<th colspan="1" style="text-align: center;">0.82</th>
<th colspan="1" style="text-align: center;">0.81</th>
</tr>
<tr>
<th colspan="1">breastcancer</th>
<th colspan="1" style="text-align: center;">0.94</th>
<th colspan="1" style="text-align: center;">0.94</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.96</b></th>
<th colspan="1" style="text-align: center; color: green;"><b>0.96</b></th>
</tr>
<tr>
<th colspan="1">diabetes</th>
<th colspan="1" style="text-align: center;">0.74</th>
<th colspan="1" style="text-align: center;">0.73</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.76</b></th>
<th colspan="1" style="text-align: center;">0.65</th>
</tr>
<tr>
<th colspan="1">heart</th>
<th colspan="1" style="text-align: center;">0.76</th>
<th colspan="1" style="text-align: center;">0.79</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.82</b></th>
<th colspan="1" style="text-align: center; color: green;"><b>0.82</b></th>
</tr>
</table>
</div>
<div style="width: 100%; overflow-x: auto;">
<table style="width: 60%; font-size: 80%; margin: auto" class="concise-table">
<tr>
<th colspan="1"></th>
<th colspan="4" style="text-align: center;">ER (higher=better)</th>
</tr>
<tr>
<th colspan="1"></th>
<th colspan="1" style="text-align: center;">DT</th>
<th colspan="1" style="text-align: center;">RobDT</th>
<th colspan="1" style="text-align: center;">LCPA</th>
<th colspan="1" style="text-align: center;">BBM-RS</th>
</tr>
<tr>
<th colspan="1">adult</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.50</b></th>
<th colspan="1" style="text-align: center; color: green;"><b>0.50</b></th>
<th colspan="1" style="text-align: center;">0.12</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.50</b></th>
</tr>
<tr>
<th colspan="1">breastcancer</th>
<th colspan="1" style="text-align: center;">0.23</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.29</b></th>
<th colspan="1" style="text-align: center;">0.28</th>
<th colspan="1" style="text-align: center;">0.27</th>
</tr>
<tr>
<th colspan="1">diabetes</th>
<th colspan="1" style="text-align: center;">0.08</th>
<th colspan="1" style="text-align: center;">0.08</th>
<th colspan="1" style="text-align: center;">0.09</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.15</b></th>
</tr>
<tr>
<th colspan="1">heart</th>
<th colspan="1" style="text-align: center;">0.23</th>
<th colspan="1" style="text-align: center;">0.31</th>
<th colspan="1" style="text-align: center;">0.14</th>
<th colspan="1" style="text-align: center; color: green;"><b>0.32</b></th>
</tr>
</table>
</div>
<div style="width: 100%; overflow-x: auto;">
<table style="width: 60%; font-size: 80%; margin: auto" class="concise-table">
<tr>
<th colspan="1"></th>
<th colspan="4" style="text-align: center;">IC feature threshold pairs (lower=better)</th>
</tr>
<tr>
<th colspan="1"></th>
<th colspan="1" style="text-align: center;">DT</th>
<th colspan="1" style="text-align: center;">RobDT</th>
<th colspan="1" style="text-align: center;">LCPA</th>
<th colspan="1" style="text-align: center;">BBM-RS</th>
</tr>
<tr>
<th colspan="1">adult</th>
<th colspan="1" style="text-align: center;">414.20</th>
<th colspan="1" style="text-align: center;">287.90</th>
<th colspan="1" style="text-align: center;">14.90</th>
<th colspan="1" style="text-align: center; color: green;"><b>6.00</b></th>
</tr>
<tr>
<th colspan="1">breastcancer</th>
<th colspan="1" style="text-align: center;">15.20</th>
<th colspan="1" style="text-align: center;">7.40</th>
<th colspan="1" style="text-align: center; color: green;"><b>6.00</b></th>
<th colspan="1" style="text-align: center;">11.00</th>
</tr>
<tr>
<th colspan="1">diabetes</th>
<th colspan="1" style="text-align: center;">31.20</th>
<th colspan="1" style="text-align: center;">27.90</th>
<th colspan="1" style="text-align: center;">6.00</th>
<th colspan="1" style="text-align: center; color: green;"><b>2.10</b></th>
</tr>
<tr>
<th colspan="1">heart</th>
<th colspan="1" style="text-align: center;">20.30</th>
<th colspan="1" style="text-align: center;">13.60</th>
<th colspan="1" style="text-align: center;">11.90</th>
<th colspan="1" style="text-align: center; color: green;"><b>9.50</b></th>
</tr>
</table>
</div>
<p>From the tables, we see that BBM-RS has a test accuracy comparable to other
methods.
In terms of robustness, it performs slightly better than the others (it is the
best on three of the four datasets).
In terms of interpretability, BBM-RS
performs the best on three out of four datasets.
All in all, we see that BBM-RS can bring better robustness and interpretability
while performing competitively on test accuracy.
This shows that BBM-RS performs well not only theoretically but also
empirically.</p>
<h2 id="conclusion">Conclusion</h2>
<p>We investigated three important properties of a classifier: accuracy, robustness, and
interpretability.
We designed and analyzed a tree-based algorithm that provably achieves all these properties, under
linear separation with a margin assumption.
Our research is a step towards building trustworthy models that provably achieve many desired
properties.</p>
<p>Our research raises many open problems.
What is the optimal dependence between accuracy, interpretation complexity,
empirical robustness, and sample complexity?
Can we have guarantees using different notions of interpretability?
We showed how to construct an interpretable, robust, and accurate model. But,
for reliable machine learning models, many more properties are required,
such as privacy and fairness.
Can we build a model with guarantees on all these properties simultaneously?</p>
<h4 id="more-details">More Details</h4>
<p>See <a href="https://arxiv.org/abs/2102.07048">our paper on arxiv</a> or <a href="https://github.com/yangarbiter/interpretable-robust-trees">our repository</a>.</p>
<p><a href='https://sites.google.com/view/michal-moshkovitz'>Michal Moshkovitz</a> and <a href='http://yyyang.me'>Yao-Yuan Yang</a></p>
<p><em>Summary:</em> Trustworthy machine learning (ML) has emerged as a crucial topic for the success of ML models. This post focuses on three fundamental properties of trustworthy ML models: high accuracy, interpretability, and robustness. Building on ideas from ensemble learning, we construct a tree-based model that is guaranteed to be adversarially robust, interpretable, and accurate on linearly separable data. Experiments confirm that our algorithm yields classifiers that are interpretable, robust, and highly accurate.</p>
<h1>Location Trace Privacy Under Conditional Priors</h1>
<p><em>2021-05-10</em> · <a href="https://ucsdml.github.io//jekyll/update/2021/05/10/location-trace-privacy">https://ucsdml.github.io//jekyll/update/2021/05/10/location-trace-privacy</a></p>
<p>Imagine a mobile app that repeatedly records your geolocation over a short period of time – say a day. We call this sequence of locations a <em>location trace</em>. Ideally, the app would like to use these locations to send recommendations or ads or even reminders. But there is the issue of privacy; many people, including myself, would feel uncomfortable if our exact locations were to be shared with and recorded by apps. One option may be to completely shut off all location services. But is it possible to have a happy medium? In other words, can we obscure a location trace of a user while still providing some privacy?</p>
<h3 id="rigorous-privacy-definitions-differential-and-inferential">Rigorous Privacy Definitions: Differential and Inferential</h3>
<p>Before we get to what privacy means in this case, let us look at how rigorous privacy definitions work. Broadly speaking, the literature has two main philosophies of rigorous definitions of statistical privacy — differential and inferential privacy. Differential privacy is an elegant privacy definition designed by cryptographers Cynthia Dwork, Frank McSherry, Kobbi Nissim and Adam Smith in 2006. The philosophy here is that the participation of a single person in the data should not make a big difference to the probability of any outcome; this, in turn, implies that an adversary watching the output of a differentially private algorithm cannot determine for sure if a certain person is in the dataset or not. Differential privacy has many elegant properties, such as robustness to auxiliary information, graceful composition, and post-processing invariance.</p>
<p>Inferential privacy in contrast means that an adversary with a certain prior knowledge does not gain a lot of extra knowledge after seeing the output of a private algorithm. While this notion is older than differential privacy, it was formalized by <a href="https://users.cs.duke.edu/~ashwin/pubs/pufferfish_TODS.pdf"> Kifer and Machanavajjhala in 2012 as the Pufferfish privacy framework</a>. Inferential privacy does not always have the elegant properties of differential privacy, but it tends to be more flexible in the sense that it can obscure some specific events. Besides, some inferential privacy frameworks or algorithms do have graceful composition and are robust to certain kinds of auxiliary information. There is a no-free-lunch theorem that states that inferential privacy against all manner of auxiliary information will imply no utility — and so there is a limit to how far this can extend.</p>
<h3 id="a-privacy-framework-for-location-traces">A Privacy Framework for Location Traces</h3>
<p>Coming back to the privacy of location traces, let us now think about some options on how to model them in a rigorous privacy framework. There are two interesting aspects about location traces. First, location is continuous spatial data — and for both privacy and utility, we may need to obscure it up to a certain distance. We call this the <em>spatiality aspect</em>. But the more challenging aspect is correlation. My location at 10am is highly correlated with my location at 10:05, and not building this into a privacy framework may lead to privacy leaks.</p>
<p>Our first option is to use local differential privacy (LDP), which is basically differential privacy applied to a single person’s data. This will mean that two traces — one in New York and one in California — will be almost indistinguishable. However, this involves adding considerable noise to each trace — so much so as to render them completely useless. We will have very good privacy, but almost no utility whatsoever.</p>
<p>Our second option is to realize that while most people may be uncomfortable sharing fine-grained location information, they may be okay with coarse-grained data. For example, since I work at UCSD, which is in La Jolla, CA, I may not mind someone knowing that I spend most of my working hours in La Jolla; but I would not want them to know my precise location. This is known as <em>geo-indistinguishability</em>, and is achieved by adding independent noise with a radius $r$ to each location. This improves utility when we are releasing a single location, but still has challenges with traces: if an adversary averages the private locations released at 10am and 10:05am, they get a better estimate of the true location, since the underlying true locations are highly correlated.</p>
<p style="text-align: center;">
<img src="/assets/2021-05-10-location-priv/plausible_solutions.png" width="80%" />
</p>
<h5 style="text-align: left;">Tradeoffs of three privacy definitions for location data: While DP prevents use of correlation, it does not allow for utility with individual traces. Geoindistinguishability works well for a single location, but cannot prevent an adversary from correlating points close by in time. Our definition (conditional inferential privacy) provides an intermediate: prevent inference against a class of priors while still offering valuable utility.</h5>
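The averaging attack on independently-noised reports can be made concrete with a small numerical sketch (synthetic numbers, not from the paper): two readings taken five minutes apart are essentially the same true location, so averaging their noised versions roughly halves the adversary's squared error.

```python
import numpy as np

# Sketch: the 10:00 and 10:05 readings share essentially the same true
# location x, but geo-indistinguishability noises each report independently.
rng = np.random.default_rng(0)
x = np.array([32.88, -117.24])          # hypothetical true (lat, lon)
sigma = 1.0
y1 = x + rng.normal(0, sigma, (100_000, 2))   # noised 10:00 reports, many trials
y2 = x + rng.normal(0, sigma, (100_000, 2))   # noised 10:05 reports, many trials

mse_single = ((y1 - x) ** 2).sum(axis=1).mean()          # ~ 2 * sigma^2
mse_avg = (((y1 + y2) / 2 - x) ** 2).sum(axis=1).mean()  # ~ sigma^2

# Averaging the two correlated reports roughly halves the adversary's error.
assert mse_avg < 0.6 * mse_single
```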
<h3 id="conditional-inferential-privacy">Conditional Inferential Privacy</h3>
<p>This brings us to our framework, Conditional Inferential Privacy (CIP). Here we aim to obscure each location to within a radius $r$, while taking into account correlation across time through a Gaussian Process prior. Gaussian processes effectively model a sequence of $n$ random variables as an $n$-dimensional vector drawn from a multivariate normal distribution (see <a href="http://www.gaussianprocess.org/gpml/chapters/RW2.pdf">Rasmussen Ch. 2</a> for more detail). In the location setting, the correlation between two locations increases with their proximity in time. Gaussian processes are frequently used to model trajectories (<a href="https://ieeexplore.ieee.org/document/7102794">Chen ‘15</a>, <a href="https://ieeexplore.ieee.org/document/1237448">Liang & Hass ‘03</a>, <a href="https://ieeexplore.ieee.org/document/709453">Liu ‘98</a>, <a href="https://ieeexplore.ieee.org/document/6126365">Kim ‘11</a>), so this serves as a good model for a prior. By modeling these correlations directly, we can obscure locations up to a radius $r$ even when an adversary tries to exploit them.</p>
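As an illustration of such a prior (a sketch with an assumed RBF kernel and lengthscale, not the paper's exact setup), one can sample one-dimensional traces from a Gaussian process and check that nearby timestamps are strongly correlated while distant ones are nearly independent:

```python
import numpy as np

# A Gaussian-process prior over a 1D location trace: n timestamps, RBF kernel,
# so correlation decays with the time gap (the lengthscale ell is an assumption).
t = np.arange(20.0)
ell = 3.0
K = np.exp(-((t[:, None] - t[None, :]) ** 2) / (2 * ell**2))

rng = np.random.default_rng(0)
# 5000 independent draws of a 20-step trace from N(0, K); tiny jitter for stability
traces = rng.multivariate_normal(np.zeros(len(t)), K + 1e-9 * np.eye(len(t)),
                                 size=5000)

corr = np.corrcoef(traces, rowvar=False)
assert corr[0, 1] > 0.8        # 1 step apart: strongly correlated
assert abs(corr[0, 15]) < 0.2  # 15 steps apart: nearly independent
```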
<p>Formally, our framework builds upon the Pufferfish inferential privacy framework. We have a set of basic secrets $S$ consisting of events $s_{x,t}$, which denotes “User was at location $x$ at time $t$”. These are the kinds of events that we would like to hide. In practice, we may choose to hide more complicated events — such as “User was at home at 10am and at the coffee shop at 10:05am”; these are modeled by a set of compound events $C$, which is essentially a set of tuples of the form $(s_{x_1, t_1}, s_{x_2, t_2}, …)$.</p>
<p>We then have the set of secret pairs $P$ which is a subset of $C \times C$ — these are the pairs of secrets that the adversary should not be able to distinguish between. Finally we have a set of priors $\Theta$, which is a set of Gaussian processes that presumably represents the adversary’s prior.</p>
<p>A mechanism $M$ is said to follow $(P, \Theta)$-CIP with parameters $(\lambda, \epsilon)$, if for all $\theta \in \Theta$ and all tuples in $(s, s’) \in P$, we have that:</p>
\[D_{\text{Renyi}, \lambda} \Big(\Pr(M(X) = Z | s, \theta ) , \Pr(M(X) = Z | s’, \theta)\Big) \leq \epsilon\]
<p>where $D_{\text{Renyi}, \lambda}$ is the Renyi divergence of order $\lambda$ (see <a href="https://arxiv.org/abs/1702.07476"> Mironov ‘17 </a> for background on Renyi divergence and its use in the privacy literature). Essentially, this means that the output distributions of the mechanism $M$ are similar under the secrets $s$ and $s'$, where “similar” means low Renyi divergence.</p>
<p>There are a couple of interesting things to note here. First, note that unlike differential privacy, here the privacy is over both the prior and the randomness in the mechanism; this is quite standard for inferential privacy. Second, observe that we use Renyi divergence in the definitions instead of the probability ratios or max divergence that is used in the standard differential privacy and Pufferfish privacy definition. This is because Renyi divergences have a natural synergy with Gaussians and Gaussian processes, which we use as our priors and mechanisms.</p>
<p>While not as elegant as differential privacy, this definition also has some good properties. We can show that we can get graceful decay of privacy for two trajectories of the same person from different time intervals — which is analogous to what is called parallel composition in the privacy literature. We also show that there is some robustness to side information. Details are in our paper.</p>
<p style="text-align: left;"><img src="/assets/2021-05-10-location-priv/three_traces.png" width="100%" /></p>
<h5 style="text-align: left;">Example of how CIP maintains high uncertainty at secret locations (times). Left: <a href="https://www.nytimes.com/interactive/2018/12/10/business/location-data-privacy-apps.html">a real location trace unknowingly collected from an NYC mayoral staff member by apps on their phone</a>. The red dots indicate sensitive locations. Middle: demonstration of how Geoindistinguishability (adding independent isotropic Gaussian noise to each location, as in the red trace) allows high certainty about the true location via correlation. The green envelope shows the posterior uncertainty of a Bayesian adversary with a Gaussian process prior (a <em>GP adversary</em>). Right: demonstration of how a CIP mechanism efficiently thwarts the same adversary's posterior at sensitive locations, given the same utility budget. The mechanism achieves this by both concentrating the noise budget near sensitive locations and by strategically correlating the added noise.</h5>
<h4 id="related-work">Related Work</h4>
<p>It is worth noting that we are in no way the first to attempt to offer meaningful location privacy. However, our method is distinguished in that it works in a continuous spatiotemporal domain, offers local privacy within a radius $r$ for sensitive locations, and has a semantically meaningful inferential guarantee. A mechanism offered by <a href="https://ieeexplore.ieee.org/document/7546522"> Bindschaedler & Shokri</a> releases synthesized traces satisfying the notion of plausible deniability, but this is distinctly different from providing a radius of privacy in the local setting, as we do. Meanwhile, the frameworks proposed by <a href="https://arxiv.org/abs/1410.5919">Xiao & Xiong (2015)</a> and <a href="http://export.arxiv.org/pdf/1810.09152">Cao et al. (2019)</a> nicely characterize the risk of inference in location traces, but use only first-order Markov models of correlation between points, do not offer a radius of indistinguishability as in this work, and are not suited to continuous-valued spatiotemporal traces.</p>
<h3 id="results">Results</h3>
<p>With the definition in place, we can now measure the privacy loss of different mechanisms. The most basic mechanism is to add zero-mean isotropic Gaussian noise with equal standard deviation to every location in the trace and publish the result; if the added noise has standard deviation $\sigma$, then we can calculate the privacy loss under CIP, as well as the mean square error utility. If a certain utility is desired, we can calibrate $\sigma$ to it and obtain a certain privacy loss.</p>
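A minimal sketch of this baseline mechanism (on a synthetic trace): pick a target per-coordinate mean squared error, calibrate $\sigma$ to it, and add iid Gaussian noise.

```python
import numpy as np

# Baseline mechanism: iid isotropic Gaussian noise on every point of a trace,
# with sigma calibrated to a desired per-coordinate mean squared error.
target_mse = 4.0
sigma = np.sqrt(target_mse)   # N(0, sigma^2) noise has MSE sigma^2 per coordinate

rng = np.random.default_rng(0)
trace = rng.normal(size=(2000, 2))   # synthetic 2D trace (long, so MSE concentrates)
release = trace + rng.normal(0.0, sigma, trace.shape)

emp_mse = ((release - trace) ** 2).mean()
assert abs(emp_mse - target_mse) < 0.5
```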
<p>A more sophisticated mechanism is to add zero-mean Gaussian noise with different covariances to locations at different time points. It turns out that we can choose the covariances to minimize privacy loss for a given utility, and this can be done by solving a Semi-Definite Program. The derivation and more details are in our paper.</p>
<p>We provide below a snap-shot of what our results look like. On the x-axis, we are plotting a measure of how correlated our prior is. If the prior is highly correlated, then it is easy to leak privacy for mechanisms that add noise — and hence correlated priors are worse for privacy. On the y-axis, we are plotting the posterior confidence interval size of the adversary — higher means higher privacy. Both mechanisms are calibrated to the same mean-square error, and hence the privacy-utility tradeoff is better if the y-axis is higher. From the figure, we see that our SDP-based mechanism does lead to a better privacy-utility tradeoff, and as expected, privacy offered declines as the correlations grow worse.</p>
<p style="text-align: left;"><img src="/assets/2021-05-10-location-priv/experiments.png" width="100%" /></p>
<h5 style="text-align: left;">Our inferentially private mechanism (blue line) maintains higher posterior uncertainty for a Bayesian adversary with a Gaussian process prior (a <em>GP adversary</em>) as compared to two Geoindistinguishability-based baselines (orange and green). The x-axis indicates the degree of correlation anticipated by the GP adversary. The left panel shows the posterior uncertainty for a single basic secret. The middle panel shows uncertainty for a compound secret. The right panel shows posterior uncertainty when we design our mechanism to maintain privacy at every location (all basic secrets). The gray window shows a range of realistic degrees of dependence (correlation) gathered from human mobility data. </h5>
<p style="text-align: left;"><img src="/assets/2021-05-10-location-priv/covariance.png" width="100%" /></p>
<h5 style="text-align: left;">Examples of the noise covariance chosen by our mechanism: Each frame is a covariance matrix optimized by our SDP mechanism to thwart inference at either a single location basic secret or a compound secret of two locations. Noise drawn from a multivariate normal with this covariance is added along the 50 point trace. The two frames on the left show covariance chosen to thwart a GP prior with an RBF kernel. The two frames on the right show covariance chosen to thwart a GP prior with a periodic kernel.</h5>
<h3 id="conclusion">Conclusion</h3>
<p>In conclusion, we take a stab at a long-standing challenge in offering location privacy — temporal correlations — and we provide a way to model them cleanly and flexibly through Gaussian Process priors. This gives us a way to quantify the privacy loss for correlated location trajectories and devise new mechanisms for sanitizing them. Our experiments show that our mechanisms offer better privacy-accuracy tradeoffs than standard baselines.</p>
<p>There are many open problems, particularly in the space of mechanism design. Can we improve the privacy-utility tradeoff offered by our mechanisms through other means, such as subsampling the traces or interpolation? Can we make our definition and our methods more robust to side information? Finally, location traces are only one example of correlated and structured data; a remaining challenge is to build upon the methodology developed here to design privacy frameworks for more complex and structured data.</p><a href='https://cseweb.ucsd.edu/~kamalika/'>Kamalika Chaudhuri</a> and Casey MeehanProviding meaningful privacy to users of location based services is particularly challenging when multiple locations are revealed in a short period of time. This is primarily due to the tremendous degree of dependence that can be anticipated between points. We propose a Renyi divergence based privacy framework, "Conditional Inferential Privacy", that quantifies this privacy risk given a class of priors on the correlation between locations. Additionally, we demonstrate an SDP-based mechanism for achieving this privacy under any Gaussian process prior. This framework both exemplifies why dependent data is so challenging to protect and offers a strategy for preserving privacy to within a fixed radius for sensitive locations in a user’s trace.The Expressive Power of Normalizing Flow Models2020-11-16T17:00:00+00:002020-11-16T17:00:00+00:00https://ucsdml.github.io//jekyll/update/2020/11/16/expressive-power-normalizing-flows<h3 id="background-generative-models-and-normalizing-flows">Background: Generative Models and Normalizing Flows</h3>
<p><a href="https://en.wikipedia.org/wiki/Generative_model">Generative models</a> are one kind of unsupervised learning model in machine learning. Given a set of training data – such as pictures of dogs, audio clips of human speakers, and articles from certain websites – a generative model aims to generate samples that look/sound like they are samples from the dataset, but are not exactly any one of them. We usually train a generative model by maximizing the probability, or likelihood, of the samples under the model.</p>
<p>To understand complicated training data, generative models usually use very large neural networks (so they are also called deep generative models). Popular deep generative models include <a href="https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf">generative adversarial networks</a> (GANs) and <a href="https://arxiv.org/pdf/1606.05908.pdf">variational autoencoders</a> (VAEs), which have achieved state-of-the-art performance on many generative tasks. Below are examples showing that <a href="https://arxiv.org/abs/1812.04948">styleGAN</a> (left) and <a href="https://arxiv.org/abs/1906.00446">VQ-VAE</a> (right) can generate amazing high-resolution images!</p>
<p style="text-align: center;"><img src="/assets/2020-11-16-nf/stylegan_demo.png" width="41%" />
<img src="/assets/2020-11-16-nf/vqvae_demo.png" width="55.2%" /></p>
<p>One might ask: as we already have powerful generative models, is everything done? No! There are many aspects in which we want to improve these models. Below are two points related to this blog.</p>
<p>First, we want to compute exact likelihood if possible. Both GANs and VAEs generate samples by applying a neural network transformation on a latent random variable $z$, which is usually a Gaussian. In this case, the sample likelihood <i> cannot </i> be exactly computed because complicated neural networks may map different $z$’s to the same output.</p>
<p>This is the reason why <a href="https://arxiv.org/abs/1908.09257">normalizing flows</a> (NFs) were proposed. An NF learns an <b>invertible</b> function $f$ (which is also a neural network) to convert a source distribution, such as a Gaussian, to the distribution of the training data. Since $f$ is invertible, we can <i> precisely </i> compute the likelihood through the change-of-variable formula! <a href="http://akosiorek.github.io/ml/2018/04/03/norm_flows.html">This post</a> includes the detailed math of the computation. Different from the decoder in VAEs and the generator in GANs (which usually transform a lower dimensional latent variable to the data distribution), the NF $f$ keeps the data dimension and $f^{-1}$ can map a sample back to the source distribution.</p>
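To illustrate the change-of-variable computation (with a toy one-layer affine map standing in for a learned flow, not an actual NF architecture):

```python
import numpy as np
from scipy.stats import norm

# Toy invertible "flow": f(z) = a*z + b applied to z ~ N(0, 1).
a, b = 2.0, 1.0
f_inv = lambda x: (x - b) / a

def log_px(x):
    # change of variables: log p_x(x) = log p_z(f^{-1}(x)) - log|f'(f^{-1}(x))|
    return norm.logpdf(f_inv(x)) - np.log(abs(a))

# f pushes N(0,1) forward to N(b, a^2); the formula recovers that density exactly
xs = np.linspace(-3.0, 5.0, 50)
assert np.allclose(log_px(xs), norm.logpdf(xs, loc=b, scale=abs(a)))
```

For a real normalizing flow, $f$ is a deep invertible network and the $\log|f'|$ term becomes the log-determinant of its Jacobian, but the likelihood computation has exactly this shape.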
<p>Second, we want a theoretical guarantee that these deep generative models are <i> potentially </i> able to learn an arbitrarily complicated data distribution. Without such theory, an <i> empirically </i> successful generative model might fail in another scenario, and we don’t want this risk to always exist! Despite its importance, this problem is super challenging due to the complicated structure of neural networks. For example, <a href="https://papers.nips.cc/paper/2018/file/9bd5ee6fe55aaeb673025dbcb8f939c1-Paper.pdf">this paper</a> analyzes GANs in transforming between very simple distributions.</p>
<p>This blog addresses the above two points by making a theoretical analysis to NFs. We provide a theoretical guarantee for NFs on $\mathbb{R}$ and some negative (impossibility) results for NFs on $\mathbb{R}^d$ where the dimension $d>1$.</p>
<h3 id="structure-of-normalizing-flows">Structure of Normalizing Flows</h3>
<p>In general, to model complex training data like images, the normalizing flow $f$ needs to be a very complicated function. In practice, $f$ is usually constructed via a sequence of simple, invertible transformations, which we call base flow layers. The figure below illustrates the middle stages within the transformation from a simple source distribution to a complicated target distribution (figure from <a href="https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html">this link</a>).</p>
<p style="text-align: center;"><img src="/assets/2020-11-16-nf/nf_model.png" width="80%" /></p>
<p>Examples of base flow layers include</p>
<ul>
<li>
<p><a href="https://arxiv.org/abs/1908.09257">planar layers</a>: $f_{\text{pf}}(z)=z+uh(w^{\top}z+b)$, where $u,w,z\in\mathbb{R}^d,b\in\mathbb{R}$;</p>
</li>
<li>
<p><a href="https://arxiv.org/abs/1908.09257">radial layers</a>: $f_{\text{rf}}(z)=z+\frac{\beta}{\alpha+\|z-z_0\|}(z-z_0)$, where $z,z_0\in\mathbb{R}^d,\alpha,\beta\in\mathbb{R}$;</p>
</li>
<li>
<p><a href="https://arxiv.org/abs/1803.05649">Sylvester layers</a>: $f_{\text{syl}}(z)=z+Ah(B^{\top}z+b)$, where $A,B\in\mathbb{R}^{d\times m}, z\in\mathbb{R}^d, b\in\mathbb{R}^m$;</p>
</li>
<li>
<p>and <a href="https://arxiv.org/abs/1611.09630">Householder layers</a>: $f_{\text{hh}}(z)=z-2vv^{\top}z$, where $v,z\in\mathbb{R}^d, v^{\top}v=1$.</p>
</li>
</ul>
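As an illustration, a planar layer with $h=\tanh$ can be implemented in a few lines, together with the log-determinant of its Jacobian that the change-of-variable formula needs (by the matrix determinant lemma, $\det(I+\psi\, u w^{\top})=1+\psi\, u^{\top}w$). The parameter values below are arbitrary, chosen only for the sketch:

```python
import numpy as np

def planar_layer(z, u, w, b):
    """One planar layer f(z) = z + u * tanh(w.z + b) on a batch z of shape (n, d).
    Also returns log|det Jf| per row, needed for exact likelihood computation."""
    a = z @ w + b                    # (n,)
    h = np.tanh(a)
    f = z + np.outer(h, u)           # (n, d)
    psi = 1.0 - h**2                 # tanh'(a)
    log_det = np.log(np.abs(1.0 + psi * (u @ w)))  # matrix determinant lemma
    return f, log_det

# Arbitrary parameters; u.w > -1 keeps the layer invertible for tanh.
rng = np.random.default_rng(0)
u, w, b = np.array([0.5, -0.3]), np.array([0.2, 0.4]), 0.1
z = rng.normal(size=(5, 2))
f, log_det = planar_layer(z, u, w, b)

# Check log_det against a finite-difference Jacobian at z[0].
eps, J = 1e-6, np.zeros((2, 2))
for j in range(2):
    zp = z[0].copy(); zp[j] += eps
    fp, _ = planar_layer(zp[None, :], u, w, b)
    J[:, j] = (fp[0] - f[0]) / eps
assert abs(np.log(abs(np.linalg.det(J))) - log_det[0]) < 1e-4
```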
<p>The number of layers is usually very large in practice. For instance, in the MNIST dataset experiments, <a href="https://arxiv.org/abs/1908.09257">this paper</a> uses 80 planar layers, and <a href="https://arxiv.org/abs/1803.05649">this paper</a> uses 16 Sylvester layers.</p>
<h3 id="defining-the-expressivity-of-normalizing-flows">Defining the Expressivity of Normalizing Flows</h3>
<p>The invertibility of NFs may hugely restrict their expressive power, but to what extent? Our <a href="http://proceedings.mlr.press/v108/kong20a/kong20a.pdf">recent paper</a> analyzes this through the following two questions:</p>
<ul>
<li>
<p><b>Q</b>1 (Exact transformation): Under what conditions is it possible to <b>exactly</b> transform the source distribution $q$ (e.g., a standard Gaussian) into the target distribution $p$ with a finite number of base flow layers?</p>
</li>
<li>
<p><b>Q</b>2 (Approximation): Since sometimes exact transformation may be hard, when is it possible to <b>approximate</b> the target distribution $p$ in <a href="https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures">total variation distance</a>? Do we need an incredibly large number of layers?</p>
</li>
</ul>
<p>Our findings:</p>
<ul>
<li>
<p>If $p$ and $q$ are defined on $\mathbb{R}$, then universal approximation can be achieved. That is, we can always transform $q$ to be arbitrarily close to any $p$.</p>
</li>
<li>
<p>If $p$ and $q$ are defined on $\mathbb{R}^d$ where $d>1$, both exact transformation and approximation may be hard. Having a large number of layers is a necessary (but not a sufficient) condition.</p>
</li>
</ul>
<h3 id="challenges">Challenges</h3>
<p>Our problem is very related to the universal approximation property: the ability of a function class to be arbitrarily close to any target function. Although we have this property for <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.441.7873&rep=rep1&type=pdf">shallow neural networks</a>, <a href="https://arxiv.org/abs/1709.02540">fully connected networks</a>, and <a href="https://arxiv.org/abs/1806.10909">residual networks</a>, these results do not apply to NFs. Why? Because of the <b>invertibility</b>.</p>
<ul>
<li>
<p>First, the fact that a function class has the universal approximation property does <b>not</b> imply that its invertible subset can approximate between any pair of distributions. For instance, take the set of piecewise constant functions: its invertible subset is the empty set!</p>
</li>
<li>
<p>On the other hand, the fact that a function class has limited capacity does <b>not</b> imply that its invertible subset <b>cannot</b> transform between any pair of distributions. For instance, take the set of triangular maps, which can perform powerful Knothe–Rosenblatt rearrangements (see page 17 of <a href="https://ljk.imag.fr/membres/Emmanuel.Maitre/lib/exe/fetch.php?media=b07.stflour.pdf">this book</a>).</p>
</li>
</ul>
<p><b>The way to get around this challenge:</b> instead of looking at the capacity of a function class in the function space, we directly analyze input–output distribution pairs.</p>
<h3 id="universal-approximation-when-d1">Universal Approximation When $d=1$</h3>
<p>As a warm-up, let us look at the one-dimensional case. We show that planar layers can approximate between arbitrary pairs of distributions under mild assumptions. We analyze a specific kind of planar layer with the ReLU activation:
\[f_{\text{pf}}(z)=z+u\ \mathrm{ReLU}(wz+b)\]
where $u,w,b,z\in\mathbb{R}$, and $\text{ReLU}(x)=\max(x,0)$. The effect of this transformation on a density is first splitting its graph into two pieces, and then scaling one piece while keeping the other one unchanged. For example, in the figure below the first planar layer splits the blue line into the solid part and the dashed part, and scales the dashed part to the orange line. Similarly, the second planar layer splits the orange line into the solid part and the dashed part, and scales the dashed part to the green line.</p>
<p style="text-align: center;"><img src="/assets/2020-11-16-nf/tail_consistent_pwg.png" width="60%" /></p>
<p>In particular, if the blue line is Gaussian, then the orange line and the green line are also pieces of some Gaussian distributions. We call this a piecewise Gaussian distribution. Additionally, it has the consistency property: the transformed density must still integrate to 1.</p>
<p>How does it relate to approximation? Here we use a fundamental result in real analysis: <a href="https://en.wikipedia.org/wiki/Lebesgue_integration">Lebesgue-integrable functions</a> can be approximated by piecewise constant functions. Given a piecewise constant distribution $q_{\text{pwc}}$ that is close to the target distribution $p$, we can iteratively construct a piecewise Gaussian distribution $q_{\text{pwg}}$ with the same group of pieces. We can additionally require $q_{\text{pwg}}$ to be very close to $q_{\text{pwc}}$ by carefully selecting the parameters $u,w,b$. Finally, as the pieces become smaller, $q_{\text{pwc}}\rightarrow p$ and $q_{\text{pwg}}\rightarrow q_{\text{pwc}}$, which implies $q_{\text{pwg}}\rightarrow p$.</p>
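The splitting-and-scaling picture can be checked numerically (a sketch with arbitrary parameters): push a standard Gaussian through one 1D ReLU planar layer and verify that the transformed density is piecewise Gaussian yet still integrates to 1.

```python
import numpy as np

# One 1D ReLU planar layer f(z) = z + u*relu(w*z + b); invertible when 1 + u*w > 0.
u, w, b = -0.5, 1.0, 0.0   # splits the density at z = -b/w = 0, scales the right piece
assert 1 + u * w > 0

z = np.linspace(-6, 6, 200_001)
fz = z + u * np.maximum(w * z + b, 0.0)
fprime = np.where(w * z + b > 0, 1 + u * w, 1.0)
pz = np.exp(-z**2 / 2) / np.sqrt(2 * np.pi)      # standard Gaussian source

# Change of variables: the density of x = f(z) is p_z(z) / |f'(z)| -- each piece
# is a rescaled piece of a Gaussian, and total mass is preserved (consistency).
px = pz / np.abs(fprime)
mass = np.sum(0.5 * (px[1:] + px[:-1]) * np.diff(fz))   # trapezoid rule in x
assert abs(mass - 1.0) < 1e-3
```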
<p>In the following example, we demonstrate such approximation with 50 (top) and 300 (bottom) ReLU planar layers, respectively.</p>
<p style="text-align: center;"><img src="/assets/2020-11-16-nf/1d_ReLU_50.png" width="60%" />
<img src="/assets/2020-11-16-nf/1d_ReLU_300.png" width="60%" /></p>
<h3 id="exact-transformation-when-d1">Exact Transformation When $d>1$</h3>
<p>Next, we look at the more general case in higher-dimensional space, which is usually quite different from the one-dimensional case. We show exact transformation between distributions can be quite hard. Specifically, we analyze Sylvester layers, a matrix-form generalization of planar layers (note that on $\mathbb{R}$, planar layers and Sylvester layers are equivalent):
\[f_{\text{syl}}(z)=z+Ah(B^{\top}z+b)\]
where $A,B\in\mathbb{R}^{d\times m},z\in\mathbb{R}^d,b\in\mathbb{R}^m$ for some integer $m$. In particular, we call $m$ the number of neurons of $f_{\text{syl}}$ because its form is identical to a residual block with $m$ neurons in the hidden layer.</p>
<p>Now suppose we stack a number of Sylvester layers with $M$ neurons in total, and these layers sequentially transform an input distribution $q$ to output distribution $p$. For convenience, let $f$ be the function composed of all these Sylvester layers. We show that the distribution pairs $(q,p)$ must obey some necessary (but not sufficient) condition, which we call the <b>topology matching</b> condition.</p>
<ul>
<li><b>$h$ is a smooth function</b></li>
</ul>
<p>Let $L(z)=\log p(f(z))-\log q(z)$ be the log-det Jacobian term. Then, the topology matching condition says that the dimension of the set of gradients of $L$ is at most the total number of neurons. Formally,
\[\dim\{\nabla_z L(z):z\in\mathbb{R}^d\}\leq M\]
In other words, if $M$ is less than the above dimensionality then exact transformation is impossible no matter what smooth non-linearities $h$ are selected.
Since it is not easy to plot $\{\nabla_z L(z):z\in\mathbb{R}^d\}$, we demonstrate $L(z)$ in a few examples below. Each row is a group, containing plots of $q$, $p$, and $L$ from left to right. In these examples, $M=1$ so $\nabla_z L(z)$ is a multiple of a constant vector.</p>
<p style="text-align: center;">→ <img src="/assets/2020-11-16-nf/general_topo_1.png" width="60%" /><br /><br />
→ <img src="/assets/2020-11-16-nf/general_topo_2.png" width="60%" /><br /><br />
→ <img src="/assets/2020-11-16-nf/general_topo_3.png" width="60%" /><br /><br />
→ <img src="/assets/2020-11-16-nf/general_topo_4.png" width="60%" /><br /><br /></p>
<p>Based on the topology matching condition, it can be shown that if the number of neurons $M$ is less than the dimension $d$, it may even be hard to transform between simple Gaussian distributions.</p>
<ul>
<li><b>When $h=\text{ReLU}$</b></li>
</ul>
<p>We then restrict to ReLU Sylvester layers. In this case, $f$ in fact performs a piecewise linear transformation in $\mathbb{R}^d$. As a result, for almost every $z\in\mathbb{R}^d$ (except for boundary points), $f$ is linear around $z$. This leads to the following (pointwise) topology matching condition: there exists a constant matrix $C$ (the Jacobian matrix of $f$ around $z$) such that
\[C^{\top}\nabla_z\log p(f(z))=\nabla_z\log q(z)\]</p>
<p>We demonstrate this result with two examples below, where each row is a $(q,p)$ distribution pair. The red points ($z$) on the left are transformed to those ($f(z)$) on the right by $f$. Notice that these red points are peaks of $q$ and $p$, respectively. In these cases, both $\nabla_z\log p(f(z))$ and $\nabla_z\log q(z)$ are zero vectors, which is compatible with the topology matching condition.</p>
<p style="text-align: center;">→ <img src="/assets/2020-11-16-nf/ReLU_topo_1.png" width="60%" /><br /><br />
→ <img src="/assets/2020-11-16-nf/ReLU_topo_2.png" width="60%" /><br /><br /></p>
<p>As a corollary, we conclude that ReLU Sylvester layers generally do not transform between product distributions or mixtures of Gaussians, except for very special cases.</p>
<h3 id="approximation-capacity-when-d1">Approximation Capacity When $d>1$</h3>
<p>It is not surprising that exact transformation between distributions is difficult. What if we loosen our goal to approximation between distributions, where we can use transformations from a certain class $\mathcal{F}$? We show that unfortunately, this is still hard under certain conditions.</p>
<p>The way to look at this problem is to bound the minimum depth that is needed to approximate between $q$ and $p$. In other words, if we use fewer than this number of transformations, then it is impossible to approximate $p$ given $q$ as the source, no matter what transformations in $\mathcal{F}$ are selected. Formally, for $\epsilon>0$, we define the minimum depth as
\[T_{\epsilon}(p,q,\mathcal{F})=\inf\{n: \exists f_1,\ldots,f_n\in\mathcal{F}\text{ such that }\mathrm{TV}((f_1\circ\cdots\circ f_n)(q),p)\leq\epsilon\}\]
<p>where $\mathrm{TV}$ is the total variation distance.</p>
<p>We conclude that if $\mathcal{F}$ is the set of $(i)$ planar layers $f_{\text{pf}}$ with bounded parameters and popular non-linearities including $\tanh$, sigmoid, and $\arctan$, or $(ii)$ all Householder layers $f_{\text{hh}}$, then $T_{\epsilon}(p,q,\mathcal{F})$ is not small. In detail, for any $\kappa>0$, there exists a pair of distributions $(q,p)$ on $\mathbb{R}^d$ and a constant $\epsilon$ (e.g., 0.5) such that
\[T_{\epsilon}(p,q,\mathcal{F})=\tilde{\Omega}(d^{\kappa})\]
Although this lower bound is polynomial in the dimension $d$, in many practical problems the dimension can be very large so the minimum depth is still an incredibly large number. This result tells us that planar layers and Householder layers are provably not very expressive under certain conditions.</p>
<h3 id="open-problems">Open Problems</h3>
<p>This is the end of <a href="http://proceedings.mlr.press/v108/kong20a/kong20a.pdf">our paper</a>, but is clearly just the beginning of the story. There are a large number of open problems on the expressive power of even simple normalizing flow transformations. Below are some potential directions.</p>
<ul>
<li>Just like neural networks, planar and Sylvester layers use non-linearities in their expressions. Is it possible that a certain combination of non-linearities (at different layers) can significantly improve capacity?</li>
<li>Our paper does not provide a result for very deep Sylvester flows (e.g., $>d$ layers) with smooth non-linearities. Therefore, it is interesting to provide some insights for deep Sylvester flows.</li>
<li>A more general problem is to understand if the universal approximation property of certain class of normalizing flows holds in converting between distributions. The result is meaningful even if we assume the depth can be arbitrarily large.</li>
<li>On the other hand, it is also helpful to analyze what these normalizing flows are good at. A good example is to show that they can easily transform between distributions in a certain class, especially by an elegant construction.</li>
</ul>
<h3 id="more-details">More Details</h3>
<p>See <a href="http://proceedings.mlr.press/v108/kong20a/kong20a.pdf">our paper</a> or <a href="https://arxiv.org/abs/2006.00392">the full paper on arxiv</a>.</p><a href='https://cseweb.ucsd.edu/~z4kong'>Zhifeng Kong</a> and <a href='http://cseweb.ucsd.edu/~kamalika'>Kamalika Chaudhuri</a>Normalizing flows have received a great deal of recent attention as they allow flexible generative modeling as well as easy likelihood computation. However, there is little formal understanding of their representation power. In this work, we study some basic normalizing flows and show that (1) they may be highly expressive in one dimension, and (2) in higher dimensions their representation power may be limited.Explainable 2-means Clustering: Five Lines Proof2020-10-16T18:00:00+00:002020-10-16T18:00:00+00:00https://ucsdml.github.io//jekyll/update/2020/10/16/explain_2_means<p><strong>TL;DR:</strong> we will show <em>why</em> only one feature is enough to define a good $2$-means clustering. And we will do it using only 5 inequalities (!)
In a <a href="explain_k_means.html">previous post</a>, we explained what an explainable clustering is.</p>
<h3 id="explainable-clustering">Explainable clustering</h3>
<p>In a <a href="explain_k_means.html">previous post</a>, we discussed why explainability is important, defined it as a small decision tree, and suggested an algorithm to find such a clustering. But why is the resulting clustering any good? We measure “good” by the <a href="https://en.wikipedia.org/wiki/K-means_clustering">$k$-means cost</a>. The cost of a clustering $C$ is defined as the sum of squared Euclidean distances of each point $x$ to its center $c(x)$. Formally,
\begin{equation}
cost(C)=\sum_x \|x-c(x)\|^2,
\end{equation} the sum is over all points $x$ in the dataset.</p>
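<p>As a concrete illustration (our own sketch, not code from the paper), this cost can be computed in a few lines; the cluster means play the role of the centers $c(x)$:</p>

```python
import numpy as np

def kmeans_cost(X, labels):
    """Sum of squared Euclidean distances from each point to its cluster mean."""
    return sum(((X[labels == c] - X[labels == c].mean(axis=0)) ** 2).sum()
               for c in np.unique(labels))
```

<p>Using each cluster’s mean as its center is what makes this the $k$-means cost: the mean is the point minimizing the sum of squared distances within a cluster.</p>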
<p>In this post, we focus on the $2$-means problem, where there are only two clusters. We want to show that for every dataset there is <strong>one</strong> feature $i$ and <strong>one</strong> threshold $\theta$ such that the following simple clustering $C^{i,\theta}=(C^{i,\theta}_1,C^{i,\theta}_2)$ has a low cost:
\begin{equation}
\text{if } x_i\leq\theta \text{ then } x\in C^{i,\theta}_1 \text{ else } x\in C^{i,\theta}_2.
\end{equation}
We call such a clustering a <em>threshold cut</em>. There might be many threshold cuts that are good, bad, or somewhere in between. We want to show that there is at least one that is good (i.e., low cost). In the <a href="https://arxiv.org/abs/2002.12538">paper,</a> we prove that there is always a threshold cut, $C^{i,\theta}$, that is almost as good as the optimal clustering:
\begin{equation}
cost(C^{i,\theta})\leq4\cdot cost(opt),
\end{equation}
where $cost(opt)$ is the cost of the optimal 2-means clustering. This means that there is a simple explainable clustering $C^{i,\theta}$ that is only $4$ times worse than the optimal one. The approximation factor is independent of the dimension and the number of points. Sounds crazy, right? Let’s see how we can prove it!</p>
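<p>To make the claim concrete, here is a brute-force sketch (ours, not the paper’s algorithm) that scans every feature $i$ and threshold $\theta$ and returns the cheapest threshold cut; on small data you can compare its cost against any $2$-means solver:</p>

```python
import numpy as np

def best_threshold_cut(X):
    """Scan all axis-aligned cuts x_i <= theta; return (i, theta, cost)."""
    best = (None, None, np.inf)
    for i in range(X.shape[1]):
        for theta in np.unique(X[:, i])[:-1]:   # keep both sides non-empty
            left, right = X[X[:, i] <= theta], X[X[:, i] > theta]
            cost = (((left - left.mean(axis=0)) ** 2).sum()
                    + ((right - right.mean(axis=0)) ** 2).sum())
            if cost < best[2]:
                best = (i, float(theta), float(cost))
    return best
```

<p>The theorem above guarantees that the cost this search finds is at most $4\cdot cost(opt)$.</p>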
<h3 id="the-minimal-mistakes-threshold-cut">The minimal-mistakes threshold cut</h3>
<p>We want to compare two clusterings: the optimal clustering and the best threshold cut. The best threshold cut is hard to analyze, so we introduce an intermediate clustering: <em>the minimal-mistakes threshold cut</em>, $\widehat{C}$. Even though this clustering will not be the best threshold cut, it will be good enough. In the paper we prove that $cost(\widehat{C})$ is at most $4cost(opt)$. For simplicity, in this post, we will show a slightly worse bound of $11cost(opt)$ instead of $4cost(opt)$.</p>
<!--Let's define what the minimal-mistakes cut is. -->
<p>We define the number of mistakes of a threshold cut $C^{i,\theta}$ as the number of points $x$ that are not in the same cluster as their optimal center $c(x)$ in $C^{i,\theta}$, i.e., the number of points $x$ such that<br />
\begin{equation}
sign(\theta-x_i) \neq sign(\theta-c(x)_i).
\end{equation}
The <em>minimal-mistakes clustering</em> is the threshold cut that has the minimal number of mistakes. Take a look at the next figure for an example.</p>
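<p>Counting the mistakes of a given cut is straightforward; this sketch (our own illustration) checks, for each point, whether the cut $x_i \leq \theta$ puts it on the same side as its optimal center:</p>

```python
import numpy as np

def count_mistakes(X, centers, labels, i, theta):
    """Points separated from their optimal center by the cut x_i <= theta."""
    point_left = X[:, i] <= theta
    center_left = centers[labels, i] <= theta   # side of each point's center
    return int((point_left != center_left).sum())
```

<p>Minimizing this count over all $(i,\theta)$ pairs yields the minimal-mistakes threshold cut $\widehat{C}$.</p>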
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_2_means/mistakes_example.png" width="30%" style="margin: 0 auto" />
<figcaption>
Two optimal clusters are in red and blue. Centers are the stars. Split (in yellow) with one mistake. This is a minimal-mistakes threshold cut, as any threshold cut has at least $1$ mistake.
</figcaption>
</figure>
<h3 id="playing-with-cost-warm-up">Playing with cost: warm-up</h3>
<p>Before we present the proof, let’s familiarize ourselves with the $k$-means cost and explore several of its properties. It will be helpful later on!</p>
<h4 id="changing-centers">Changing centers</h4>
<p>If we change the centers of a clustering from their means (which are their optimal centers) to different centers $c=(c_1, c_2)$, then the cost can only increase. Putting this into math, denote by $cost(C,c)$ the cost of clustering $C=(C_1,C_2)$ when $c_1$ is the center of cluster $C_1$ and $c_2$ is the center of cluster $C_2$, then</p>
<p>\begin{align}
cost(C) &= \sum_{x\in C_1} \|x-mean(C_1)\|^2 + \sum_{x\in C_2} \|x-mean(C_2)\|^2 \newline &\leq \sum_{x\in C_1} \|x-c_1\|^2 + \sum_{x\in C_2} \|x-c_2\|^2 = cost(C,c).
\end{align}
What if we further want to change the centers from some arbitrary centers $(c_1, c_2)$ to other arbitrary centers $(m_1, m_2)$? How does the cost change? Can we bound it? To our rescue comes the (almost) triangle inequality that states that for any two vectors $x,y$:
\begin{equation}
\|x+y\|^2 \leq 2\|x\|^2+2\|y\|^2.
\end{equation}
This implies that the cost of changing the centers from $c=(c_1, c_2)$ to $m=(m_1, m_2)$ is bounded by
\begin{equation}
cost(C,c)\leq 2cost(C,m)+2|C_1|\|c_1-m_1\|^2+2|C_2|\|c_2-m_2\|^2.
\end{equation}</p>
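<p>This bound follows by writing $x-c_j=(x-m_j)+(m_j-c_j)$ for every point and applying the almost triangle inequality. A quick randomized sanity check (our own sketch, not from the paper):</p>

```python
import numpy as np

def cost_with_centers(X, labels, centers):
    """k-means cost when cluster j is forced to use centers[j]."""
    return ((X - centers[labels]) ** 2).sum()

def check_changing_centers_bound(X, labels, c, m):
    """Verify cost(C,c) <= 2*cost(C,m) + 2*sum_j |C_j|*||c_j - m_j||^2."""
    sizes = np.bincount(labels, minlength=len(c))
    lhs = cost_with_centers(X, labels, c)
    rhs = 2 * cost_with_centers(X, labels, m) \
        + 2 * (sizes * ((c - m) ** 2).sum(axis=1)).sum()
    return lhs <= rhs + 1e-9
```

<p>The inequality holds for every choice of centers, so random instances should never falsify it.</p>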
<h4 id="decomposing-the-cost">Decomposing the cost</h4>
<p>The cost can be easily decomposed with respect to the data points and the features. Let’s start with the data points. For any partition of the points in $C$ into $S_1$ and $S_2$, the cost can be rewritten as
\begin{equation}
cost(C,c)=cost(C \cap S_1,c)+cost(C \cap S_2,c).
\end{equation}
The cost can also be decomposed with respect to the features, because we are using the squared Euclidean distance. To be more specific, the cost incurred by the $i$-th feature is $cost_i(C,c)=\sum_{x}(x_i-c(x)_i)^2,$ and the total cost is equal to
\begin{equation}
cost(C,c)=\sum_i cost_i(C,c).
\end{equation}
If the last equation is unclear, just recall the definition of the cost ($c(x)$ is the center of a point $x$):
\begin{equation}
cost(C,c)=\sum_{x}\|x-c(x)\|^2=\sum_i\sum_{x}(x_i-c(x)_i)^2=\sum_icost_i(C,c).
\end{equation}</p>
<h3 id="the-5-line-proof">The 5-line proof</h3>
<p>Now we are ready to prove that $\widehat{C}$ is only a constant factor worse than the optimal $2$-means clustering:
\begin{equation}
cost(\widehat{C})\leq 11\cdot cost(opt).
\end{equation}</p>
<p>To prove that the minimal-mistakes threshold cut $\widehat{C}$ gives a low-cost clustering, we will do something that might look strange at first. We analyze the quality of $\widehat{C}$ using the centers of the <em>optimal</em> clustering, not the optimal centers for $\widehat{C}$ itself. This step can only increase the cost, so why do it? Because it eases the analysis, and if there are not many mistakes, then the centers do not change much, as in the previous figure, so the increase is small. So here comes the first step: change the centers of $\widehat{C}$ to the optimal centers $c^*=(mean(C^*_1),mean(C^*_2))$. Recall from the warm-up that this can only increase the cost:
\begin{equation}
cost(\widehat{C})\leq cost(\widehat{C},c^{*}) \quad (1)
\end{equation}
Next we use one of the decomposition properties of the cost. We partition the dataset into the set of points that are correctly labeled, $X^{cor}$, and those that are not, $X^{wro}$.</p>
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_2_means/mistakes_example_wrong.png" width="30%" style="margin: 0 auto" />
<figcaption>
The same dataset and split as before. The point with a grey circle is in the wrong cluster and is the only member of $X^{wro}$. All other points have the same cluster assignment as in the optimal clustering and are in $X^{cor}$.
</figcaption>
</figure>
<p>Thus, we can rewrite the last term as
\begin{equation}
cost(\widehat{C},c^{*})=cost(\widehat{C}\cap X^{cor},c^{*})+cost(\widehat{C}\cap X^{wro},c^{*}) \quad (2)
\end{equation}</p>
<p>Let’s look at this sum. The first term contains all the points that have their correct center in $c^*$ (which is either $mean(C^*_1)$ or $mean(C^*_2)$). Hence, the first term in (2) is easy to bound: it’s at most $cost(opt)$. So from now on, we focus on the second term.</p>
<p>In the second term, all points are in $X^{wro}$, which means they were assigned to the incorrect optimal center. So let’s change the centers once more, so that $X^{wro}$ will have the correct centers. The correct centers of $X^{wro}$ are the same centers $c^*$, but the order is reversed, i.e., all points assigned to center $mean(C^*_1)$ are now assigned to $mean(C^*_2)$ and vice versa. Using the “changing centers” property of the cost we discussed earlier, we have <!--, the second term in (2) is at most--></p>
<p>\begin{equation}
cost(\widehat{C},c^{*}) \leq 3cost(opt)+2|X^{wro}|\cdot\|c^{*}_1-c^{*}_2\|^2 \quad (3)
\end{equation}</p>
<p>Now we’ve reached the main step in the proof. We show that the second term in (3) is bounded by $8cost(opt)$. We first decompose $cost(opt)$ using the features. Then, all we need to show is that:</p>
<p>\begin{equation}
cost_i(opt)\geq\left(\frac{|c^{*}_{1,i}-c^{*}_{2,i}|}{2}\right)^2|X^{wro}| \quad (4)
\end{equation}</p>
<p>The trick is, for each feature, to focus on the threshold cut defined by the middle point between the two optimal centers. Since $\widehat{C}$ is the minimal-mistakes clustering, we know that every threshold cut has at least $|X^{wro}|$ mistakes. Each mistake contributes at least the square of half the distance between the two centers (in that feature).</p>
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_2_means/IMM_blog_pic_4.png" width="30%" style="margin: 0 auto" />
<figcaption>
Proving step (4): projecting to feature $i$. Points in blue belong to the first cluster, and points in red to the second. We focus on the cut at the mid-point between the two optimal centers.
</figcaption>
</figure>
<p>This figure shows how to prove step (4). We see that there is $1$ mistake, which is the minimum possible. This means that even the optimal clustering must pay at least the square of half the distance between the centers for each of these mistakes. This gives us a lower bound on $cost_i(opt)$ in this feature. Then we can sum over all the features to see that the second term of (3) is at most $8cost(opt)$, which is what we wanted. <!--Since the whole expression in (3) is at most $10cost(opt)$, and we lose another $cost(opt)$ from the first term of (2), we can put these together to get-->
<!--Summing everything together we achieve our goal:-->
Putting everything together, we get exactly what we wanted to prove in this post:
\begin{equation}
cost(\widehat{C})\leq 11\cdot cost(opt) \quad (5)
\end{equation}
<!--That's it!--></p>
<h3 id="epilogue-improvements">Epilogue: improvements</h3>
<p>The bound that we got, $11$, is not the best possible. With more tricks, one of them using Hall’s theorem, we can get a bound of $4$. Similar ideas provide a $2$-approximation to the optimal $2$-medians clustering as well.
To complement our upper bounds, we also prove lower bounds showing that any threshold cut must incur an approximation factor of almost $3$ for $2$-means and almost $2$ for $2$-medians. You can read all about it in our <a href="https://proceedings.icml.cc/paper/2020/file/8e489b4966fe8f703b5be647f1cbae63-Paper.pdf">paper</a>.</p><a href='https://sites.google.com/view/michal-moshkovitz'>Michal Moshkovitz</a>, <a href='mailto:navefrost@mail.tau.ac.il'>Nave Frost</a>, <a href='https://sites.google.com/site/cyrusrashtchian/'>Cyrus Rashtchian</a>In a previous post, we discussed tree-based clustering and how to develop explainable clustering algorithms with provable guarantees. Now we will show why only one feature is enough to define a good 2-means clustering. And we will do it using only 5 inequalities (!)Explainable k-means Clustering2020-10-16T17:00:00+00:002020-10-16T17:00:00+00:00https://ucsdml.github.io//jekyll/update/2020/10/16/explain_k_means<p><strong>TL;DR:</strong>
Explainable AI has gained a lot of interest in the last few years, but effective methods for unsupervised learning are scarce. And the rare methods that do exist do not have provable guarantees. We present a new algorithm for explainable clustering that is provably good for $k$-means clustering — the Iterative Mistake Minimization (IMM) algorithm. Specifically, we want to build a clustering defined by a small decision tree. Overall, this post summarizes our new paper: <a href="https://arxiv.org/pdf/2002.12538.pdf">Explainable $k$-Means and $k$-Medians clustering</a>.</p>
<h3 id="explainability-why">Explainability: why?</h3>
<p>Machine learning models are mostly “black box”. They give good results, but their reasoning is unclear. These days, machine learning is entering fields like healthcare (e.g., for a better understanding of <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6543980/#:~:text=In%20the%20medical%20field%2C%20clustering,in%20labeled%20and%20unlabeled%20datasets.&text=The%20aim%20is%20to%20provide,AD%20based%20on%20their%20similarity.">Alzheimer’s Disease</a> and <a href="https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0118453#sec013">Breast Cancer</a>), transportation, and law. In these fields, quality is not the only objective. No matter how well a computer is making its predictions, we can’t even imagine blindly following a computer’s suggestion. Can you imagine blindly medicating or performing surgery on a patient just because a computer said so? Instead, it would be much better to provide insight into what parts of the data the algorithm used to make its prediction.</p>
<h3 id="tree-based-explainable-clustering">Tree-based explainable clustering</h3>
<!--Despite the popularity of explainability, there is limited work in unsupervised learning. To remedy it, -->
<p>We study a prominent problem in unsupervised learning, $k$-means clustering. We are given a dataset, and the goal is to partition it into $k$ clusters such that the <a href="https://en.wikipedia.org/wiki/K-means_clustering">$k$-means cost</a> is minimal. The cost of a clustering $C=(C^1,\ldots,C^k)$ is the sum of squared distances of all points from their optimal centers, $mean(C^i)$:</p>
<p>\[cost(C)=\sum_{i=1}^k\sum_{x\in C^i} \lVert x-mean(C^i)\rVert ^2.\]</p>
<p>For any cluster, $C^i$, one possible explanation of this cluster is $mean(C^i)$. In a low-cost clustering, the center is close to its points, and they are close to each other. For example, see the next figure.</p>
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_k_means/intro_IMM_blog_pic_1.png" width="40%" style="margin: 0 auto" />
<figcaption>
Near optimal 5-means clustering
</figcaption>
</figure>
<p>Unfortunately, this explanation is not as useful as it could be. The centers themselves may depend on all the data points and all the features in a complicated way. We instead aim to develop a clustering method that is explainable by design. To explain why a point is in a cluster, we will only need to look at a small number of features, and we will just evaluate a threshold for each feature one by one. This allows us to extract information about which features cause a point to go to one cluster compared to another. This method also means that we can derive an explanation that does not depend on the centers.</p>
<p>More formally, at each step we test if $x_i\leq \theta$ or not, for some feature $i$ and threshold $\theta$. We call this test a <strong>split</strong>. According to the test’s result, we decide on the next step. In the end, the algorithm returns the cluster identity. This procedure is exactly a decision tree where the leaves correspond to clusters.</p>
<p>Importantly, for the tree to be explainable it should be <strong>small</strong>. The smallest decision tree has $k$ leaves since each cluster must appear in at least one leaf. We call a clustering defined by a decision tree with $k$ leaves a <strong>tree-based explainable clustering</strong>. See the next tree for an illustration.</p>
<p align="center">
<tr>
<td> <img src="/assets/2020-10-16-explain_k_means/intro_IMM_blog_pic_2.png" width="40%" style="margin: 0 auto" /> </td>
<td> <img src="/assets/2020-10-16-explain_k_means/intro_IMM_blog_pic_3.png" width="40%" style="margin: 0 auto" /> </td>
</tr>
</p>
<!--
{:refdef: style="text-align: center;"}
<figure class="image">
<img src="/assets/2020-06-06/intro_IMM_blog_pic_2.png" width="40%" style="margin: 0 auto">
<figcaption>
Decision tree
</figcaption>
</figure>
{:refdef}
{:refdef: style="text-align: center;"}
<figure class="image">
<img src="/assets/2020-06-06/intro_IMM_blog_pic_3.png" width="40%" style="margin: 0 auto">
<figcaption>
Geometric representation of the decision tree
</figcaption>
</figure>
{:refdef}
-->
<p>On the left, we see a decision tree that defines a clustering with $5$ clusters. On the right, we see the geometric representation of this decision tree. We see that the decision tree imposes a partition to $5$ clusters aligned to the axis. The clustering looks close to the optimal clustering that we started with. Which is great. But can we do it for all datasets? How?</p>
<p>Several algorithms, such as <a href="https://link.springer.com/chapter/10.1007/11362197_5">CLTree</a> and <a href="https://www.researchgate.net/profile/Ricardo_Fraiman/publication/47744381_Clustering_using_Unsupervised_Binary_Trees_CUBT/links/09e41508aeaf39a453000000/Clustering-using-Unsupervised-Binary-Trees-CUBT.pdf">CUBT</a>, try to find a tree-based explainable clustering. But we are the first to give formal guarantees. We first need to define the quality of an algorithm. Unsupervised learning problems are commonly <a href="http://cseweb.ucsd.edu/~dasgupta/papers/kmeans.pdf">NP-hard</a>, and clustering is no exception, so it is common to settle for an approximate solution. A bit more formally, an algorithm that returns a tree-based clustering $T$ is an <em>$a$-approximation</em> if $cost(T)\leq a\cdot cost(opt),$ where $opt$ is the clustering that minimizes the $k$-means cost.</p>
<h3 id="general-scheme">General scheme</h3>
<p>Many supervised learning algorithms learn a decision tree; can we use one of them here? Yes, after we transform the problem into a supervised learning problem! How, you might ask? We can use any clustering algorithm that will return a good, but not explainable, clustering. This will form the labeling. Next, we can use a supervised algorithm that learns a decision tree. Let’s summarize these three steps:</p>
<ol>
<li>Find a clustering using some clustering algorithm</li>
<li>Label each example according to its cluster</li>
<li>Call a supervised algorithm that learns a decision tree</li>
</ol>
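<p>For illustration only (scikit-learn is our choice here, not the paper’s), the three steps fit in a few lines; as shown next for ID3, greedy impurity-based tree learners can nevertheless fail badly on some datasets:</p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

# Toy data: three well-separated blobs (positions chosen by us for the demo).
X, _ = make_blobs(n_samples=200, centers=[[0, 0], [8, 0], [0, 8]],
                  cluster_std=1.0, random_state=0)

# Step 1: run any good (non-explainable) clustering algorithm.
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

# Steps 2-3: treat cluster ids as class labels and fit a k-leaf decision tree.
tree = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0).fit(X, labels)
```

<p>The resulting tree has at most $k=3$ leaves, so each leaf can be read off as a cluster. The catch, discussed below, is that the impurity criterion these learners optimize is the wrong objective for the $k$-means cost.</p>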
<p>Which algorithm can we use in step 3? Maybe the popular ID3 algorithm?</p>
<h3 id="can-we-use-the-id3-algorithm">Can we use the ID3 algorithm?</h3>
<p>Short answer: no.</p>
<p>One might hope that in step 3, in the previous scheme, the known <a href="https://link.springer.com/content/pdf/10.1007/BF00116251.pdf">ID3</a> algorithm can be used (or one of its variants like <a href="https://link.springer.com/article/10.1007/BF00993309">C4.5</a>). We will show that this does not work. There are datasets where ID3 will perform poorly. Here is an example:</p>
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_k_means/intro_IMM_blog_pic_4.png" width="40%" style="margin: 0 auto" />
<figcaption>
ID3 performs poorly on this dataset
</figcaption>
</figure>
<p>The dataset is composed of three clusters, as you can see in the figure above. Two large clusters (0 and 1 in the figure) have centers (-2, 0) and (2, 0), respectively, plus small noise. The third cluster (2 in the figure) is composed of only two points that are very, very (very) far away from clusters 0 and 1. Given these data, ID3 will prefer to maximize the information gain and split between clusters 0 and 1. Recall that the final tree has only three leaves. This means that in the final tree, one point in cluster 2 must be with cluster 0 or cluster 1. Thus the cost is enormous.
To solve this problem, we design a new algorithm called <a href="https://proceedings.icml.cc/paper/2020/file/8e489b4966fe8f703b5be647f1cbae63-Paper.pdf"><em>Iterative Mistake Minimization (IMM)</em></a>.</p>
<h3 id="imm-algorithm-for-explainable-clustering">IMM algorithm for explainable clustering</h3>
<p>We learned that the ID3 algorithm cannot be used in step 3 at the general scheme. Before we give up on this scheme, can we use a different decision-tree algorithm? Well, since we wrote this post, you probably know the answer: there is such an algorithm, the IMM algorithm.</p>
<p>We build the tree greedily from top to bottom. At each step we take the split (i.e., feature and threshold) that minimizes a new quantity: the number of <strong>mistakes</strong>. A point $x$ is a mistake for node $u$ if $x$ and its center $c(x)$ reach $u$ and are then separated by $u$’s split. See the next figure for an example of a split with one mistake.</p>
<figure class="image" style="text-align: center;">
<img src="/assets/2020-10-16-explain_k_means/mistakes_example.png" width="40%" style="margin: 0 auto" />
<figcaption>
Split (in yellow) with one mistake. Two optimal clusters are in red and blue. Centers are the stars.
</figcaption>
</figure>
<!--For another example of the mistakes concept, let's go back to the previous dataset where ID3 failed. Focus on the first split again. The ID3 split has one mistake since one of the points in cluster $2$ will be separated from its center. On the other hand, the horizontal split has $0$ mistakes: the two large clusters will go with their centers to one side of the tree, and the small cluster will go with its center to the other side of the tree. -->
<p>To summarize, the high-level description of the IMM algorithm:
<!--<center>
<span style="font-family:Papyrus; font-size:2em;align-self: center;">As long as there is more than one center
<br> find the split with minimal number of mistakes</span>
</center>
--></p>
<center>
<span style="font-size:larger;">
As long as there is more than one center
<br /> find the split with minimal number of mistakes
</span>
</center>
<p> </p>
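<p>A minimal runnable sketch of this loop (our own simplification, with a brute-force threshold search; <code>labels</code> are indices into the full <code>centers</code> array of the reference clustering):</p>

```python
import numpy as np

def find_split(points, labels, centers, center_ids):
    """Cheapest (feature, threshold) cut, measured in number of mistakes."""
    best = (None, None, np.inf)
    for i in range(points.shape[1]):
        lo, hi = centers[center_ids, i].min(), centers[center_ids, i].max()
        for theta in np.unique(points[:, i]):
            if not (lo <= theta < hi):
                continue  # the cut must put at least one center on each side
            m = ((points[:, i] <= theta) != (centers[labels, i] <= theta)).sum()
            if m < best[2]:
                best = (i, theta, m)
    return best[0], best[1]

def imm(points, labels, centers, center_ids):
    """Greedy IMM tree: split until each leaf holds exactly one center."""
    if len(center_ids) == 1:
        return {"leaf": int(center_ids[0])}
    i, theta = find_split(points, labels, centers, center_ids)
    pl = points[:, i] <= theta               # which points go left
    cl = centers[center_ids, i] <= theta     # which centers go left
    return {"feature": int(i), "theta": float(theta),
            "left": imm(points[pl], labels[pl], centers, center_ids[cl]),
            "right": imm(points[~pl], labels[~pl], centers, center_ids[~cl])}
```

<p>The paper’s algorithm handles mistaken points and the split search more carefully; this sketch simply recomputes mistakes from scratch at every node.</p>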
<!--What if there are no mistakes.
The main definition that we need is a mistake:
Creare a different figure that explains a mistake with small number of points
-->
<!--
<center>
<span style="font-family:Papyrus; font-size:2em;align-self: center;">If a point and its center diverge,
<br> then it counts as a mistake</span>
</center>
<div class="definition"> [mistake at node $u$].
If a point and its center end up at different leafs, then it counts as a mistake.
</div>
... Explain what is a split early on ...
-->
<!---
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">IMM</span><span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">centers</span><span class="p">):</span>
<span class="n">node</span> <span class="o">=</span> <span class="n">new</span> <span class="n">Node</span><span class="p">()</span>
<span class="k">if</span> <span class="o">|</span><span class="n">centers</span><span class="o">|</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">i</span><span class="p">,</span> <span class="n">theta</span> <span class="o">=</span> <span class="n">find_split</span><span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">centers</span><span class="p">)</span>
<span class="n">node</span><span class="p">.</span><span class="n">condition</span> <span class="o">=</span> <span class="s">'x_i <= theta'</span>
<span class="n">points_left_mask</span> <span class="o">=</span> <span class="n">points</span><span class="p">[:,</span><span class="n">i</span><span class="p">]</span> <span class="o"><=</span> <span class="n">theta</span>
<span class="n">centers_left_mask</span> <span class="o">=</span> <span class="n">centers</span><span class="p">[:,</span><span class="n">i</span><span class="p">]</span> <span class="o"><=</span> <span class="n">theta</span>
<span class="n">node</span><span class="p">.</span><span class="n">left</span> <span class="o">=</span> <span class="n">IMM</span><span class="p">(</span><span class="n">points</span><span class="p">[</span><span class="n">points_left_mask</span><span class="p">],</span> <span class="n">centers</span><span class="p">[</span><span class="n">centers_left_mask</span><span class="p">])</span>
<span class="n">node</span><span class="p">.</span><span class="n">right</span> <span class="o">=</span> <span class="n">IMM</span><span class="p">(</span><span class="n">points</span><span class="p">[</span><span class="o">~</span><span class="n">points_left_mask</span><span class="p">],</span> <span class="n">centers</span><span class="p">[</span><span class="o">~</span><span class="n">centers_left_mask</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">node</span><span class="p">.</span><span class="n">label</span> <span class="o">=</span> <span class="n">centers</span>
<span class="k">return</span> <span class="n">node</span>
<span class="k">def</span> <span class="nf">find_split</span><span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">centers</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">d</span><span class="p">):</span>
<span class="n">l</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">centers</span><span class="p">[:,</span><span class="n">i</span><span class="p">])</span>
<span class="n">r</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">centers</span><span class="p">[:,</span><span class="n">i</span><span class="p">])</span>
<span class="n">i</span><span class="p">,</span><span class="n">theta</span> <span class="o">=</span> <span class="n">argmin_</span><span class="p">{</span><span class="n">i</span><span class="p">,</span><span class="n">l</span> <span class="o"><=</span> <span class="n">theta</span> <span class="o"><</span> <span class="n">r</span><span class="p">}</span> <span class="n">mistakes</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">theta</span><span class="p">)</span>
<span class="k">return</span> <span class="n">i</span><span class="p">,</span><span class="n">theta</span></code></pre></figure>
-->
<p>Here is an illustration of the IMM algorithm. We use $k$-means++ with $k=5$ to find a clustering for our dataset. Each point is colored with its cluster label. At each node in the tree, we choose a split with a minimal number of mistakes. We stop where each of the $k=5$ centers is in its own leaf. This defines the explainable clustering on the left.</p>
<center>
<img src="/assets/2020-10-16-explain_k_means/imm_example_slow.gif" width="600" height="320" />
</center>
<p>The algorithm is guaranteed to perform well. For any dataset. See the next theorem.</p>
<div class="theorem">
IMM is an $O(k^2)$-approximation to the optimal $k$-means clustering.
</div>
<p>This theorem shows that we can always find a small tree, with $k$ leaves, such that the tree-based clustering is only $O(k^2)$ times worse in terms of the cost. IMM efficiently finds this explainable clustering. Importantly, this approximation is independent of the dimension and the number of points. A proof for the case $k=2$ will appear in a <a href="explain_2_means.html">follow-up post</a>, and you can read the proof for general $k$ in the paper. Intuitively, we discovered that the number of mistakes is a good indicator for the $k$-means cost, and so, minimizing the number of mistakes is an effective way to find a low-cost clustering. <!-- Surprisingly, we can also use a tree with $k$ leaves, which means that IMM produces an explainable clustering.--></p>
<h4 id="running-time">Running Time</h4>
<p>What is the running time of the IMM algorithm? With an efficient implementation, using dynamic programming, the running time is $O(kdn\log(n)).$ Why? For each of the $k-1$ inner nodes and each of the $d$ features, we can find the split that minimizes the number of mistakes for this node and feature, in time $O(n\log(n)).$</p>
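<p>For one node and one feature, the $O(n\log(n))$ search can be sketched as a sweep (our own illustration, not the paper’s implementation): sort all point and center coordinates, move the threshold rightward, and update the mistake count in constant time whenever a point or a center crosses it:</p>

```python
import numpy as np

def min_mistake_split_1d(x, cx, labels):
    """Minimal-mistake threshold for a single feature in O(n log n).

    x: feature values of the points; cx: feature values of the centers;
    labels: each point's center index.
    """
    k = len(cx)
    n_j = np.bincount(labels, minlength=k)    # points per center
    left = np.zeros(k, dtype=int)             # points of j with value <= theta
    center_left = np.zeros(k, dtype=bool)
    mistakes = 0                              # theta below everything: 0 mistakes
    lo, hi = cx.min(), cx.max()               # a valid cut separates the centers
    events = sorted([(v, 0, j) for j, v in enumerate(cx)]
                    + [(v, 1, p) for p, v in enumerate(x)])
    best_m, best_theta = np.inf, None
    i = 0
    while i < len(events):
        v = events[i][0]
        while i < len(events) and events[i][0] == v:  # process ties together
            _, kind, idx = events[i]
            if kind == 0:                     # center idx crosses to the left
                mistakes += n_j[idx] - 2 * left[idx]
                center_left[idx] = True
            else:                             # point idx crosses to the left
                j = labels[idx]
                mistakes += -1 if center_left[j] else 1
                left[j] += 1
            i += 1
        if lo <= v < hi and mistakes < best_m:
            best_m, best_theta = mistakes, v
    return int(best_m), best_theta
```

<p>Sorting dominates; each crossing updates the count in $O(1)$, so a full node costs $O(dn\log(n))$ over all features.</p>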
<p>For $2$-means one can do better than running IMM: go over all possible $(n-1)d$ cuts and find the best one. The running time is $O(nd^2+nd\log(n))$.</p>
<h3 id="results-summary">Results Summary</h3>
<p>In each cell in the following table, we write the approximation factor. We want this value to be small for the upper bounds and large for the lower bounds. In $2$-medians, the upper and lower bounds are pretty tight, about $2$. But there is a large gap for $k$-means and $k$-medians: the lower bound is $\log(k)$, while the upper bound is $\mathsf{poly}(k)$.</p>
<center>
<table style="text-align: center">
<thead>
<tr>
<th></th>
<th colspan="2" style="text-align: center">$k$-medians</th>
<th colspan="2" style="text-align: center">$k$-means</th>
</tr>
<tr>
<th></th>
<th> $k=2$ </th>
<th> $k>2$ </th>
<th> $k=2$ </th>
<th> $k>2$ </th>
</tr>
</thead>
<tbody>
<tr>
<td> <strong>Lower</strong> </td>
<td> $2-\frac1d$ </td>
<td> $\Omega(\log k)$ </td>
<td> $3\left(1-\frac1d\right)^2$ </td>
<td> $\Omega(\log k)$ </td>
</tr>
<tr>
<td> <strong>Upper</strong> </td>
<td> $2$ </td>
<td> $O(k)$ </td>
<td> $4$ </td>
<td> $O(k^2)$ </td>
</tr>
</tbody>
</table>
</center>
<h3 id="whats-next">What’s next</h3>
<ol>
<li>IMM exhibits excellent results in practice on many datasets, see <a href="https://arxiv.org/abs/2006.02399">this</a>. Its running time is comparable to KMeans as implemented in sklearn. We implemented the IMM algorithm; it’s <a href="https://github.com/navefr/ExKMC">here</a>. Try it yourself.</li>
<li>We plan to have several posts on explainable clusterings, here is the <a href="explain_2_means.html">second</a> in the series, stay tuned for more!</li>
<li>In a follow-up work, we explore the tradeoff between explainability and accuracy. If we allow a slightly larger tree, can we get a lower cost? We introduce the <a href="https://arxiv.org/abs/2006.02399">ExKMC</a>, “Expanding Explainable $k$-Means Clustering”, algorithm that builds on IMM.</li>
<li>Found cool applications of IMM? Let us know!</li>
</ol><a href='https://sites.google.com/view/michal-moshkovitz'>Michal Moshkovitz</a>, <a href='mailto:navefrost@mail.tau.ac.il'>Nave Frost</a>, <a href='https://sites.google.com/site/cyrusrashtchian/'>Cyrus Rashtchian</a>Popular algorithms for learning decision trees can be arbitrarily bad for clustering. We present a new algorithm for explainable clustering that has provable guarantees --- the Iterative Mistake Minimization (IMM) algorithm. This algorithm exhibits good results in practice. Its running time is comparable to KMeans implemented in sklearn. So our method gives you explanations basically for free. Our code is available on github.Towards Physics-informed Deep Learning for Turbulent Flow Prediction2020-08-23T00:00:00+00:002020-08-23T00:00:00+00:00https://ucsdml.github.io//jekyll/update/2020/08/23/TF-Net<h3 id="prediction-visualization">Prediction Visualization</h3>
<p>We propose a novel hybrid model for turbulence prediction, $\texttt{TF-Net}$, that unifies a popular <a href="https://en.wikipedia.org/wiki/Computational_fluid_dynamics">Computational fluid dynamics (CFD)</a> technique, RANS-LES coupling, with a custom-designed U-net. The following two videos show the ground truth and the predicted U (left) and V (right) velocity fields from $\texttt{TF-Net}$ and the three best baselines. We see that the predictions by $\texttt{TF-Net}$ are the closest to the target based on the shape and the frequency of the motions. Baselines generate smooth predictions and miss the details of small-scale motions.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/U_prediction.gif" width="49%" style="margin: 0 auto" />
<img src="/assets/2020-08-23-TF-Net/V_prediction.gif" width="49%" style="margin: 0 auto" />
</div>
<p><br /></p>
<h3 id="introduction">Introduction</h3>
<p>Modeling the spatiotemporal dynamics over a wide range of space and time scales is a fundamental task in science, especially atmospheric science, marine science and aerodynamics. <a href="https://en.wikipedia.org/wiki/Computational_fluid_dynamics">Computational fluid dynamics (CFD)</a> is at the heart of climate modeling and has direct implications for understanding and predicting climate change. Recently, deep learning has demonstrated great success in the <a href="https://www.nature.com/articles/s41586-019-0912-1">automation, acceleration, and streamlining of highly compute-intensive workflows for science</a>. We hope deep learning can accelerate turbulence simulation, since current CFD is purely physics-based and computationally intensive, requiring significant computational resources and expertise.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/imgs.png" width="90%" style="margin: 0 auto" />
</div>
<p><br />
However, purely data-driven methods are statistical, incorporate no underlying physical knowledge, and have yet to prove successful at accurately capturing and predicting complex physical systems. Incorporating physics knowledge into deep learning models can improve not only prediction accuracy but, more importantly, physical consistency. Thus, developing deep learning methods that can incorporate physical laws in a systematic manner is a key element in advancing AI for the physical sciences.</p>
<p><a href="https://uknowledge.uky.edu/me_textbooks/2/">Computational techniques</a> are at the core of present-day turbulence investigation, a branch of fluid mechanics that uses numerical methods to analyze and predict fluid flows. In physics, the motion of a viscous fluid is described by the <a href="https://en.wikipedia.org/wiki/Navier%E2%80%93Stokes_equations">Navier–Stokes equations</a>:</p>
\[\nabla \cdot \pmb{w} = 0 \qquad\qquad\qquad\qquad\qquad\qquad\qquad \text{Continuity Equation}\]
\[\frac{\partial \pmb{w}}{\partial t} + (\pmb{w} \cdot \nabla) \pmb{w} = -\frac{1}{\rho_0} \nabla p + \nu \nabla^2 \pmb{w} + f \quad\text{Momentum Equation}\]
\[\frac{\partial T}{\partial t} + (\pmb{w} \cdot \nabla) T = \kappa \nabla^2 T \qquad\qquad\qquad\quad \text{Temperature Equation}\]
<p>where $\pmb{w}(t)$ is the vector velocity field of the flow, which is what we want to predict; $p$ and $T$ are pressure and temperature, respectively; $\kappa$ is the coefficient of heat conductivity; $\rho_0$ is the density at the initial temperature; $\alpha$ is the coefficient of thermal expansion; $\nu$ is the kinematic viscosity; and $f$ is the body force due to gravity.</p>
<p><br /></p>
<h3 id="turbulent-flow-net">Turbulent-Flow Net</h3>
<p>For turbulent flows, the range of length scales and the complexity of the phenomena involved make <a href="https://en.wikipedia.org/wiki/Direct_numerical_simulation">Direct Numerical Simulation (DNS)</a> approaches prohibitively expensive. Great emphasis has therefore been placed on alternative approaches, including Large-Eddy Simulation (LES), Reynolds-averaged Navier–Stokes (RANS), and <a href="https://link.springer.com/article/10.1007/s10494-017-9828-8">Hybrid RANS-LES Coupling</a>, which combines RANS and LES in order to take advantage of both methods. These approaches decompose the fluid flow into different scales so as to directly simulate the large scales while modeling the small ones.</p>
<p>Hybrid RANS-LES Coupling decomposes the flow velocity into three scales: mean flow, resolved fluctuations and unresolved fluctuations. It applies the spatial filtering operator $S$ and the temporal average operator $T$ sequentially.</p>
\[\pmb{w^*}(\pmb{x},t) = S \ast\pmb{w} = \sum_{\pmb{\xi}} S(\pmb{x}|\pmb{\xi})\pmb{w}(\pmb{\xi},t)\]
\[\pmb{\bar{w}}(\pmb{x},t) = T \ast \pmb{w^*} = \frac{1}{n}\sum_{s = t-n}^tT(s) \pmb{w^*} (\pmb{x}, s)\]
<p>then $\pmb{\tilde{w}}$ can be defined as the difference between $\pmb{w^*}$ and $\pmb{\bar{w}}$:</p>
\[\pmb{\tilde{w}} = \pmb{w^*} - \pmb{\bar{w}}, \quad \pmb{w'} = \pmb{w} - \pmb{w^{*}}\]
<p>Finally we can have the three-level decomposition of the velocity field.</p>
<p>\begin{equation}
\pmb{w} = \pmb{\bar{w}} + \pmb{\tilde{w}} + \pmb{w'}
\end{equation}</p>
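<p>To make the decomposition concrete, here is a NumPy sketch in which fixed box filters stand in for the spatial filter $S$ and the temporal filter $T$ (in $\texttt{TF-Net}$, both filters are trainable; the filter sizes below are arbitrary illustration choices):</p>

```python
import numpy as np

def decompose(w, spatial_size=5, temporal_size=4):
    """Three-level decomposition w = w_bar + w_tilde + w_prime of a velocity
    sequence with shape (time, height, width). Fixed box filters stand in
    for the trainable filters S and T used in TF-Net."""
    k, pad = spatial_size, spatial_size // 2
    # spatial filter S: k x k moving average applied to every frame -> w*
    wp = np.pad(w, ((0, 0), (pad, pad), (pad, pad)), mode="edge")
    w_star = np.zeros_like(w)
    for i in range(k):
        for j in range(k):
            w_star += wp[:, i:i + w.shape[1], j:j + w.shape[2]]
    w_star /= k * k
    # temporal filter T: average of w* over the preceding frames -> w_bar
    w_bar = np.zeros_like(w_star)
    for t in range(w.shape[0]):
        w_bar[t] = w_star[max(0, t - temporal_size + 1):t + 1].mean(axis=0)
    w_tilde = w_star - w_bar   # resolved fluctuations
    w_prime = w - w_star       # unresolved fluctuations
    return w_bar, w_tilde, w_prime

rng = np.random.default_rng(1)
w = rng.standard_normal((8, 16, 16))
w_bar, w_tilde, w_prime = decompose(w)  # the three components sum back to w
```

By construction, the three components sum exactly back to the original field, mirroring the decomposition above.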
<p>The figure below shows this three-level decomposition in wavenumber space. $k$ is the wavenumber, the spatial frequency in the Fourier domain. $E(k)$ is the energy spectrum describing how much kinetic energy is contained in eddies with wavenumber $k$. Small $k$ corresponds to large eddies that contain most of the energy. The slope of the spectrum is negative and indicates the transfer of energy from large scales of motion to the small scales.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/decompose.png" width="40%" style="margin: 0 auto" />
</div>
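<p>The energy spectrum itself can be estimated from a velocity field with FFTs. A simplified, radially binned version (synthetic input; a square grid with unit spacing is assumed):</p>

```python
import numpy as np

def energy_spectrum(u, v):
    """Radially averaged kinetic-energy spectrum E(k) of a 2D velocity field."""
    n = u.shape[0]                       # assume a square n x n grid
    uh, vh = np.fft.fft2(u), np.fft.fft2(v)
    ke = 0.5 * (np.abs(uh) ** 2 + np.abs(vh) ** 2) / n ** 4  # spectral KE
    kx = np.fft.fftfreq(n, d=1.0 / n)    # integer wavenumbers
    kmag = np.sqrt(kx[:, None] ** 2 + kx[None, :] ** 2)
    kbins = np.arange(1, n // 2)
    # sum the energy in annular shells k - 0.5 <= |k| < k + 0.5
    E = np.array([ke[(kmag >= k - 0.5) & (kmag < k + 0.5)].sum() for k in kbins])
    return kbins, E

rng = np.random.default_rng(0)
u = rng.standard_normal((64, 64))
v = rng.standard_normal((64, 64))
k, E = energy_spectrum(u, v)  # plot E against k on log-log axes
```

For white noise the spectrum is roughly flat; a turbulent flow would show the decaying slope sketched in the figure.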
<p><br />
Inspired by the hybrid RANS-LES Coupling, we propose a hybrid deep learning framework, $\texttt{TF-Net}$, based on the multilevel spectral decomposition. Specifically, we decompose the velocity field into three scales using the spatial filter $S$ and the temporal filter $T$. Unlike traditional CFD, both filters in $\texttt{TF-Net}$ are trainable neural networks. The motivation for this design is to explicitly guide the DL model to learn the non-linear dynamics of both large and small eddies. We design three identical convolutional encoders to encode the three scale components separately and use a shared convolutional decoder to learn the interactions among these three components and generate the final prediction. The figure below shows the overall architecture of our hybrid model $\texttt{TF-Net}$.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/model.png" width="98%" style="margin: 0 auto" />
</div>
<p><br /></p>
<p>Since the turbulent flow under investigation has zero divergence, we include $\Vert\nabla \cdot \pmb{w}\Vert^2$ as a regularizer to constrain the predictions, leading to a constrained TF-Net, $\texttt{Con TF-Net}$.</p>
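<p>A sketch of this divergence-free regularizer using central finite differences (NumPy-only, for illustration; during training the term would be computed with the deep learning framework's differentiable operations):</p>

```python
import numpy as np

def divergence_penalty(u, v):
    """Mean squared divergence ||div w||^2 of a 2D field w = (u, v),
    estimated with finite differences."""
    du_dx = np.gradient(u, axis=1)   # x varies along columns
    dv_dy = np.gradient(v, axis=0)   # y varies along rows
    return np.mean((du_dx + dv_dy) ** 2)

# an analytically divergence-free field: u = cos(x)cos(y), v = sin(x)sin(y)
y, x = np.mgrid[0:2 * np.pi:64j, 0:2 * np.pi:64j]
u = np.cos(x) * np.cos(y)
v = np.sin(x) * np.sin(y)
penalty = divergence_penalty(u, v)   # near zero for a divergence-free field
```

Adding this penalty to the prediction loss pushes the model toward physically consistent, mass-conserving outputs.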
<p><br />
<br /></p>
<h3 id="results">Results</h3>
<p>We compare our model with four purely data-driven deep learning models, including <a href="https://arxiv.org/abs/1512.03385">$\texttt{ResNet}$</a>, <a href="https://arxiv.org/abs/1506.04214">$\texttt{ConvLSTM}$</a>, <a href="https://arxiv.org/abs/1505.04597">$\texttt{U-net}$</a> and <a href="https://arxiv.org/abs/1406.2661">$\texttt{GAN}$</a>, and two hybrid physics-informed models, including <a href="https://arxiv.org/abs/1801.06637">$\texttt{DHPM}$</a> and <a href="https://arxiv.org/abs/1711.07970">$\texttt{SST}$</a>. All models are trained to make one-step-ahead predictions given the historic frames, and we apply them autoregressively to generate multi-step forecasts.</p>
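<p>The autoregressive rollout is the same for every model; schematically (the persistence "model" below is just a placeholder standing in for a trained network):</p>

```python
import numpy as np

def rollout(model, history, steps):
    """Autoregressive multi-step forecast: feed the model's own one-step
    predictions back in as input. `model` maps a window of k frames to the
    next frame; `history` has shape (k, H, W)."""
    window = list(history)
    preds = []
    for _ in range(steps):
        nxt = model(np.stack(window))
        preds.append(nxt)
        window = window[1:] + [nxt]   # slide the input window forward
    return np.stack(preds)

# toy "model": persistence of the most recent frame (not TF-Net)
persistence = lambda frames: frames[-1]
hist = np.arange(3 * 4 * 4, dtype=float).reshape(3, 4, 4)
forecast = rollout(persistence, hist, steps=5)
```

Because each prediction is fed back in, one-step errors compound over the horizon, which is why the RMSE curves below grow with the number of steps.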
<p>$\textbf{Accuracy}$ The following figure shows the growth of RMSE with prediction horizon up to 60 time steps ahead. $\texttt{TF-Net}$ consistently outperforms all baselines, and constraining it with the divergence-free regularizer further improves performance.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/rmse_horizon.png" width="55%" style="margin: 0 auto" />
</div>
<p><br /></p>
<p>$\textbf{Physical Consistency}$ The left figure below shows the average absolute divergence over all pixels at each prediction step, and the right figure shows the energy spectrum curves. $\texttt{TF-Net}$ predictions are much closer to the target even without the additional divergence-free constraint, which suggests that $\texttt{TF-Net}$ generates predictions that are physically consistent with the ground truth.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/divergence.png" width="48%" style="margin: 0 auto" />
<img src="/assets/2020-08-23-TF-Net/spec_ci_square.png" width="48%" style="margin: 0 auto" />
</div>
<p><br /></p>
<p>$\textbf{Efficiency}$ This figure shows the average time to produce one 64 × 448 2D velocity field for all models on a single V100 GPU. $\texttt{TF-Net}$, $\texttt{U-net}$ and $\texttt{GAN}$ are faster than the numerical Lattice Boltzmann method, and the speed advantage of $\texttt{TF-Net}$ grows at higher resolutions.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/avg_time.png" width="60%" style="margin: 0 auto" />
</div>
<p><br />
$\textbf{Ablation Study}$ We also perform an ablation study to understand each component of $\texttt{TF-Net}$ and to investigate whether the model has actually learned flows at different scales. During inference, we apply each small U-net in $\texttt{TF-Net}$ to the entire input domain with the other two encoders zeroed out. The video below shows $\texttt{TF-Net}$ predictions alongside the outputs of each small $\texttt{U-net}$. The outputs of the individual $\texttt{U-net}$s correspond to flow at different scales, which demonstrates that $\texttt{TF-Net}$ learns multi-scale behaviors.</p>
<div style="text-align: center;">
<img src="/assets/2020-08-23-TF-Net/Ablation_Study.gif" width="70%" style="margin: 0 auto" />
</div>
<p><br />
<br /></p>
<h3 id="conclusion-and-future-work">Conclusion and Future Work</h3>
<p>We presented a novel hybrid deep learning model, $\texttt{TF-Net}$, that unifies representation learning and turbulence simulation techniques. $\texttt{TF-Net}$ exploits the multi-scale behavior of turbulent flows to design trainable scale-separation operators that model different ranges of scales individually. We provide exhaustive comparisons of $\texttt{TF-Net}$ and baselines and observe significant improvement in both prediction error and desired physical quantities, including divergence, turbulence kinetic energy and the energy spectrum. Future work includes extending these techniques to very high-resolution, 3D turbulent flows, incorporating additional physical variables such as pressure and temperature, and adding physical constraints such as conservation of momentum to improve the accuracy and faithfulness of deep learning models.</p>
<h3 id="more-details">More Details</h3>
<h4 id="see-our-paper-or-our-repository">See <a href="https://arxiv.org/abs/1911.08655">our paper</a> or our <a href="https://github.com/Rose-STL-Lab/Turbulent-Flow-Net">repository</a>.</h4><a href='mailto:ruw020@ucsd.edu'>Rui Wang</a>, <a href='mailto:kkashinath@lbl.gov'>Karthik Kashinath</a>, <a href='mailto:mmustafa@lbl.gov'>Mustafa Mustafa</a>, <a href='mailto:aalbert@lbl.gov'>Adrian Albert</a> and <a href='mailto:roseyu@eng.ucsd.edu'>Rose Yu</a>While deep learning has shown tremendous success in a wide range of domains, it remains a grand challenge to incorporate physical principles in a systematic manner to the design, training, and inference of such models. In this paper, we aim to predict turbulent flow by learning its highly nonlinear dynamics from spatiotemporal velocity fields of large-scale fluid flow simulations of relevance to turbulence modeling and climate modeling. We adopt a hybrid approach by marrying two well-established turbulent flow simulation techniques with deep learning. Specifically, we introduce trainable spectral filters in a coupled model of Reynolds-averaged Navier-Stokes (RANS) and Large Eddy Simulation (LES), followed by a specialized U-net for prediction. Our approach, which we call turbulent-Flow Net (TF-Net), is grounded in a principled physics model, yet offers the flexibility of learned representations. We compare our model, TF-Net, with state-of-the-art baselines and observe significant reductions in error for predictions 60 frames ahead. 
Most importantly, our method predicts physical fields that obey desirable physical characteristics, such as conservation of mass, whilst faithfully emulating the turbulent kinetic energy field and spectrum, which are critical for accurate prediction of turbulent flows.How to Detect Data-Copying in Generative Models2020-08-03T19:00:00+00:002020-08-03T19:00:00+00:00https://ucsdml.github.io//jekyll/update/2020/08/03/how-to-detect-data-copying-in-generative-models<p>In our <a href="https://arxiv.org/abs/2004.05675">AISTATS 2020 paper</a>, professors <a href="https://cseweb.ucsd.edu/~kamalika/">Kamalika Chaudhuri</a>, <a href="https://cseweb.ucsd.edu/~dasgupta/">Sanjoy Dasgupta</a>, and I propose some new definitions and test statistics for conceptualizing and measuring overfitting by generative models.</p>
<p>Overfitting is a basic stumbling block of any learning process. Take learning to cook for example. In quarantine, I’ve attempted ~60 new recipes and can recreate ~45 of them consistently. The recipes are my training set and the fraction I can recreate is a sort of training error. While this training error is not exactly impressive, if you ask me to riff on these recipes and improvise, the result (i.e. dinner) will be dramatically worse.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/supervised_overfitting_2.png" width="75%" /></p>
<p>It is well understood that our models tend to do the same – deftly regurgitating their training data, yet struggling to generalize to unseen examples similar to the training data. Learning theory has nicely formalized this in the supervised setting. Our classification and regression models start to overfit when we observe a gap between training and (held-out) test prediction error, as in the above figure for the overly complex models.</p>
<p>This notion of overfitting relies on being able to measure prediction error or perhaps log likelihood of the labels, which is rarely a barrier in the supervised setting; supervised models generally output low dimensional, simple predictions. Such is not the case in the generative setting where we ask models to output original, high dimensional, complex entities like images or natural language. Here, we certainly lack any notion of prediction error and likelihoods are intractable for many of today’s generative models like VAEs and GANs: VAEs only provide a lower bound of the data likelihood, and GANs only leave us with their samples. This prevents us from simply measuring the gap between train and test accuracy/likelihood and calling it a day as we do with supervised models.</p>
<p>Instead, we evaluate generative models by comparing their generated samples with those of the true distribution, as in the following figure. Here, a two-sample test only uses a training sample and a generated sample. A three-sample test uses an additional held out test sample from the true distribution.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/unsupervised_setting_2.png" width="75%" /></p>
<p>This practice is well established by existing two-sample generative model tests like the <a href="https://arxiv.org/abs/1706.08500">Frechet Inception Distance</a>, <a href="https://arxiv.org/abs/1611.04488">Kernel MMD</a>, and <a href="https://arxiv.org/abs/1806.00035">Precision & Recall test</a>. But in the absence of ground truth labels, what exactly are we testing for? We argue that unlike supervised models, generative models exhibit two varieties of overfitting: <strong>over-representation</strong> and <strong>data-copying</strong>.</p>
<h3 id="data-copying-vs-over-representation">Data-copying vs. Over-representation</h3>
<p>Most generative model tests like those listed above check for over-representation: the tendency of a model to over-emphasize certain regions of the instance space by assigning more probability mass there than it should. Consider a data distribution $P$ over an instance space $\mathcal{X}$ of cat cartoons. Region $\mathcal{C} \subset \mathcal{X}$ specifically contains cartoons of cats with bats. Using training set $T \sim P$, we train a generative model $Q$ from which we draw a sample $Q_m \sim Q$.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/overrepresentation.png" width="95%" /></p>
<p>Evidently, the model $Q$ really likes region $\mathcal{C}$, generating an undue share of cats with bats. More formally, we say $Q$ is over-representing some region $\mathcal{C}$ when</p>
<p>\[ \Pr_{x \sim Q}[x \in \mathcal{C}] \gg \Pr_{x \sim P}[x \in \mathcal{C}] \]</p>
<p>This can be measured with a simple two-sample hypothesis test, as was done in Richardson & Weiss’s <a href="https://arxiv.org/abs/1805.12462">2018 paper</a> demonstrating the efficacy of Gaussian mixture models in high dimension.</p>
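<p>As a minimal sketch of such a two-sample check, one can estimate $\Pr[x \in \mathcal{C}]$ from each sample and compare the two proportions with a $z$-test (the membership indicators below are synthetic; this is an illustration, not the exact procedure of Richardson & Weiss):</p>

```python
import numpy as np

def over_rep_z(in_c_q, in_c_p):
    """Two-proportion z-test for over-representation of a region C.
    in_c_q, in_c_p: boolean arrays marking which generated / true samples
    fall in C. Large positive z suggests Q over-represents C."""
    m, n = len(in_c_q), len(in_c_p)
    pq, pp = in_c_q.mean(), in_c_p.mean()
    pool = (in_c_q.sum() + in_c_p.sum()) / (m + n)  # pooled proportion
    se = np.sqrt(pool * (1 - pool) * (1 / m + 1 / n))
    return (pq - pp) / se

rng = np.random.default_rng(0)
# Q puts ~30% of its mass in C while P puts only ~10% there
z = over_rep_z(rng.random(1000) < 0.3, rng.random(1000) < 0.1)
```

A strongly positive $z$ flags the "cats with bats" behavior above: the model assigns far more mass to $\mathcal{C}$ than the true distribution does.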
<p>Data-copying, on the other hand, occurs when $Q$ produces samples that are <em>closer to training set $T$</em> than they should be. To test for this, we equip ourselves with a held-out test sample $P_n \sim P$ in addition to some distance metric $d(x,T)$ that measures proximity to the training set of any $x \in \mathcal{X}$. We then say that $Q$ is data-copying training set $T$ when examples $x \sim Q$ are on average closer to $T$ than are $x \sim P$.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/data_copying_1_.png" width="95%" /></p>
<p>We define proximity to training set $d(x,T)$ to be the distance between $x$ and its nearest neighbor in $T$ according to some metric $d_\mathcal{X}:\mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}$. Specifically</p>
<p>\[ d(x,T) = \min_{t \in T}d_\mathcal{X}(x,t) \]</p>
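<p>For illustration, $d(x,T)$ can be computed by brute force with Euclidean distance (in the experiments below, distances are instead taken in an embedded perceptual space):</p>

```python
import numpy as np

def dist_to_train(X, T):
    """d(x, T): Euclidean distance from each row of X to its nearest
    neighbor in the training set T (brute force, for illustration)."""
    # pairwise squared distances via |x - t|^2 = |x|^2 - 2 x.t + |t|^2
    d2 = (X ** 2).sum(1)[:, None] - 2 * X @ T.T + (T ** 2).sum(1)[None, :]
    return np.sqrt(np.maximum(d2, 0).min(axis=1))

T = np.array([[0.0, 0.0], [10.0, 0.0]])
X = np.array([[1.0, 0.0], [10.0, 2.0]])
d = dist_to_train(X, T)   # -> [1.0, 2.0]
```

For large sets, a k-d tree or approximate nearest-neighbor index would replace the brute-force pairwise computation.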
<p>At first glance, the generated samples in the above figure look perfectly fine, representing the different regions nicely. But taken alongside its training and test sets, we see that the model has effectively copied the cat with bat in the lower right corner (for visualization, we let Euclidean distance $d_\mathcal{X}$ be a proxy for similarity).</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/data_copying_2.png" width="95%" /></p>
<p>More formally, $Q$ is data-copying $T$ in some region $\mathcal{C} \subset \mathcal{X}$ when</p>
<p>\[ \Pr_{x \sim Q, z \sim P}[d(x,T) < d(z,T) \mid x,z \in \mathcal{C}] \gg \frac{1}{2}\]</p>
<p>The key takeaway here is that data-copying and over-representation are <em>orthogonal failure modes</em> of generative models. A model that exhibits over-representation may or may not data-copy and vice versa. As such, it is critical that we test for both failure modes when designing and training models.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/orthogonal_concepts_2_.png" width="70%" /></p>
<p>Returning to my failed culinary ambitions, I tend to both data-copy recipes I’ve tried <em>and</em> over-represent certain types of cuisine. If you look at the ‘true distribution’ of recipes online, you will find that there is a tremendous diversity of cooking styles and cuisines. However, put in the unfortunate circumstance of having me cook for you, I will most likely produce some slight variation of a recipe I’ve recently tried. And, even though I have attempted a number of Indian, Mexican, Italian, and French dishes, I tend to over-represent bland pastas and salads when left to my own devices. To cook truly original food, one must both be creative enough to go beyond the recipes they’ve seen <em>and</em> versatile enough to make a variety of cuisines. So, be sure to test for both data-copying and over-representation, and do not ask me to cook for you.</p>
<h3 id="a-three-sample-test-for-data-copying">A Three-Sample Test for Data-Copying</h3>
<p>Adding another test to one’s modeling pipeline is tedious. The good news is that data-copying can be tested with a single snappy three-sample hypothesis test. It is non-parametric and concentrates nicely as both the test set and the generated sample grow.</p>
<p>As described in the previous section, we use a training sample $T \sim P$, a held-out test sample $P_n \sim P$, and a generated sample $Q_m \sim Q$. We additionally need some distance metric $d_\mathcal{X}(x,z)$. In practice, we choose $d_\mathcal{X}(x,z)$ to be the Euclidean distance between $x$ and $z$ after being embedded by $\phi$ into some lower-dimensional perceptual space: $d_\mathcal{X}(x,z) = \| \phi(x) - \phi(z) \|_2$. The use of such embeddings is common practice in testing generative models as exhibited by several existing over-representation tests like <a href="https://arxiv.org/abs/1706.08500">Frechet Inception Distance</a> and <a href="https://arxiv.org/abs/1806.00035">Precision & Recall</a>.</p>
<p>Following intuition, it is tempting to check for data-copying by simply differencing the expected distance to training set:</p>
<div>
$$
\mathbb{E}_{x \sim Q} [d(x,T)] - \mathbb{E}_{x \sim P} [d(x,T)] \approx \frac{1}{m} \sum_{x_i \in Q_m} d(x_i, T) - \frac{1}{n} \sum_{x_i \in P_n}d(x_i, T) \ll 0
$$
</div>
<p>where, to reiterate, $d(x,T)$ is the distance $d_\mathcal{X}$ between $x$ and its nearest neighbor in $T$. This statistic — an expected distance — is a little too finicky: the variance is far out of our control, influenced by both the choice of distance metric and by outliers in both $P_n$ and $Q_m$. So, instead of probing for how <em>much</em> closer $Q$ is to $T$ than $P$ is, we probe for how <em>often</em> $Q$ is closer to $T$ than $P$ is:</p>
<div>
$$
\mathbb{E}_{x \sim Q, z \sim P} [\mathbb{1}_{d(x,T) > d(z,T)}] \approx \frac{1}{nm} \sum_{x_i \in Q_m, z_j \in P_n} \mathbb{1} \big( d(x_i, T) > d(z_j, T) \big) \ll \frac{1}{2}
$$
</div>
<p>This statistic — a probability — is closer to what we want to measure, and is more stable. It tells us how much more likely samples in $Q_m$ are to fall near samples in $T$ relative to the held out samples in $P_n$. If it is much less than a half, then significant data-copying is occurring. This statistic is much more robust to outliers and is lower variance. Additionally, by measuring a probability instead of an expected distance, the value of this statistic is interpretable. Regardless of the data domain or distance metric, less than half is overfit, half is good, and over half is underfit (in the sense that the generated samples are further from the training set than they should be). We are also able to show that this indicator statistic has nice concentration properties agnostic to the chosen distance metric.</p>
<p>It turns out that the above test is an instantiation of the <a href="https://en.wikipedia.org/wiki/Mann-Whitney_U_test">Mann-Whitney hypothesis test</a>, proposed in 1947, for which there are computationally efficient implementations in packages like <a href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html">SciPy</a>. By $Z$-scoring the Mann-Whitney statistic, we normalize its mean to zero and variance to one. We call this statistic $Z_U$. As such, a generative model $Q$ with $Z_U \ll 0$ is heavily data-copying and a score $Z_U \gg 0$ is underfitting. Near 0 is ideal.</p>
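<p>A minimal NumPy sketch of $Z_U$: compute the Mann-Whitney $U$ statistic over the two sets of nearest-neighbor distances and $Z$-score it under the null (SciPy's <code>mannwhitneyu</code> provides an equivalent, more careful implementation; ties are ignored here for simplicity):</p>

```python
import numpy as np

def z_u(d_q, d_p):
    """Z-scored Mann-Whitney U statistic over nearest-neighbor distances.
    d_q: distances d(x, T) for generated samples; d_p: for held-out test
    samples. Z << 0 suggests data-copying, Z >> 0 underfitting."""
    m, n = len(d_q), len(d_p)
    # U counts pairs where the generated sample is FARTHER from T
    u = (d_q[:, None] > d_p[None, :]).sum()
    mean = m * n / 2.0
    std = np.sqrt(m * n * (m + n + 1) / 12.0)
    return (u - mean) / std

rng = np.random.default_rng(0)
d_test = rng.uniform(0.5, 1.5, size=500)
z_copy = z_u(rng.uniform(0.0, 0.5, size=500), d_test)  # too close: Z << 0
z_ok = z_u(rng.uniform(0.5, 1.5, size=500), d_test)    # well matched: Z near 0
```

Under the null hypothesis that generated and test distances are exchangeable, $Z_U$ is approximately standard normal, which is what makes the "near 0 is ideal" reading possible.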
<h3 id="handling-heterogeneity">Handling Heterogeneity</h3>
<p>An operative phrase that you may have noticed in the above definition of data-copying is “on average”. Is the generative model closer to the training data than it should be <em>on average</em>? This, unfortunately, is prone to false negatives. If $Z_U \ll 0$, then $Q$ is certainly data-copying in some region $\mathcal{C} \subset \mathcal{X}$. However, if $Z_U \geq 0$, it may still be excessively data-copying in one region and significantly underfitting in another, leading to a test score near 0.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/bins_1.png" width="33%" /></p>
<p>For example, let the $\times$’s denote training samples and the red dots denote generated samples. Even without observing a held-out test sample, it is clear that $Q$ is data-copying in pink region and underfitting in the green region. $Z_U$ will fall near 0, suggesting the model is performing well despite this highly undesirable behavior.</p>
<p>To prevent this misreading, we employ an algorithmic tool seen frequently in non-parametric testing: binning. Break the instance space into a partition $\Pi$ consisting of $k$ ‘bins’ or ‘cells’ $\pi \in \Pi$ and compute $Z_U^\pi$ within each cell $\pi$.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/bins_2.png" width="33%" /></p>
<p>The statistic maintains its concentration properties within each cell. The more test and generated samples we have ($n$ and $m$), the more bins we can construct, and the more precisely we can pinpoint a model’s data-copying behavior. The ‘goodness’ of a model’s fit is an inherently multidimensional entity, and it is informative to explore the range of $Z_U^\pi$ values seen across all cells $\pi \in \Pi$. Our experiments indicate that VAEs and GANs both tend to data-copy in some cells and underfit in others. However, to boil all this down into a single statistic for model comparisons, we simply take an average of the $Z_U^\pi$ values weighted by the number of test samples in the cell:</p>
<div>
$$
C_T = \sum_{\pi \in \Pi} \frac{\#\{P_n \in \pi\}}{n} Z_U^\pi
$$
</div>
<p>(In practice, we restrict ourselves to cells with a sufficient number of generated samples; see the <a href="https://arxiv.org/abs/2004.05675">paper</a>.) Intuitively, this statistic tells us whether the model tends to data-copy in the regions most heavily emphasized by the true distribution. It does not tell us whether or not the model $Q$ data-copies <em>somewhere</em>.</p>
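<p>A sketch of $C_T$, assuming cell labels for the test and generated samples have already been computed (e.g. with $k$-means on the embedded training data); the cutoff <code>min_gen</code> mirrors the paper's restriction to cells with enough generated samples:</p>

```python
import numpy as np

def z_u(d_q, d_p):
    """Z-scored Mann-Whitney U statistic over nearest-neighbor distances."""
    m, n = len(d_q), len(d_p)
    u = (d_q[:, None] > d_p[None, :]).sum()
    return (u - m * n / 2.0) / np.sqrt(m * n * (m + n + 1) / 12.0)

def c_t(d_q, d_p, cells_q, cells_p, min_gen=20):
    """C_T: per-cell Z_U scores averaged with weights #{P_n in cell} / n.
    Cells with fewer than `min_gen` generated samples are skipped."""
    score, weight = 0.0, 0.0
    for cell in np.unique(cells_p):
        q_in, p_in = d_q[cells_q == cell], d_p[cells_p == cell]
        if len(q_in) < min_gen:
            continue
        w = len(p_in) / len(d_p)
        score += w * z_u(q_in, p_in)
        weight += w
    return score / weight if weight > 0 else 0.0

# toy example: cell 0 is data-copied (distances too small), cell 1 is well fit
rng = np.random.default_rng(0)
d_q = np.concatenate([rng.uniform(0.0, 0.1, 100), rng.uniform(0.5, 1.5, 100)])
d_p = np.concatenate([rng.uniform(0.5, 1.5, 100), rng.uniform(0.5, 1.5, 100)])
cells = np.repeat([0, 1], 100)
score = c_t(d_q, d_p, cells, cells)   # negative: data-copying dominates
```

Note that the well-fit cell pulls the average back toward zero, which is exactly the false-negative risk the per-cell $Z_U^\pi$ values guard against.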
<h3 id="experiments-data-copying-in-the-wild">Experiments: data-copying in the wild</h3>
<p>Observing data-copying in VAEs and GANs indicates that the $C_T$ statistic above serves as an instructive tool for model selection. For a more methodical interrogation of the $C_T$ statistic and comparison with baseline tests, be sure to check out the <a href="https://arxiv.org/abs/2004.05675">paper</a>.</p>
<p>To test how VAE complexity relates to data-copying, we train 20 VAEs on MNIST with increasing width as indicated by the latent dimension. For each model $Q$, we draw a sample of generated images $Q_m$, and compare with a held out test set $P_n$ to measure $C_T$. Our distance metric is given by the 64d latent space of an autoencoder we trained with a VGG perceptual loss produced by <a href="https://arxiv.org/abs/1801.03924">Zhang et al.</a>. The purpose of this alternative latent space is to provide an embedding that both provides a perceptual distance between images and is independent of the VAE embeddings. For partitioning, we simply take the Voronoi cells induced by the $k$ centroids found by $k$-means run on the embedded training dataset.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/VAE_overfitting.png" width="49%" />
<img src="/assets/2020-08-03-data-copying/VAE_gen_gap.png" width="46%" /></p>
<h5 style="text-align: center;">The data-copying $C_T$ statistic (left) captures overfitting in overly complex VAEs. The train/test gap in ELBO (right), meanwhile, does not.</h5>
<p>Recall that $C_T \ll 0$ indicates data-copying and $C_T \gg 0$ indicates underfitting. We see (above, left) that overly complex models (towards the left of the plot) tend to copy their training set, and simple models (towards the right of the plot) tend to underfit, just as we might expect. Furthermore, $C_T = 0$ approximately coincides with the maximum ELBO, the VAE’s likelihood lower bound. For comparison, take the generalization gap of the VAEs’ ELBO on the training and test sets (above, right). The gap remains large for both overly complex models ($d > 50$) and simple models ($d < 50$). With the ELBO being a lower bound to the likelihood, it is difficult to interpret precisely why this happens. Regardless, it is clear that the ELBO gap is a comparatively imprecise measure of overfitting.</p>
<p>While the VAEs exhibit increasing data-copying with model complexity <em>on average</em>, most of them have cells that are over- and underfit. Poking into the individual cells $\pi \in \Pi$, we can take a look at the difference between a $Z_U^\pi \ll 0$ cell and a $Z_U^\pi \gg 0$ cell:</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/VAE_cells.png" width="90%" /></p>
<h5 style="text-align: center;"> A VAE's datacopied (left) vs. underfit (right) cells of the MNIST instance space.</h5>
<p>The two strips exhibit two regions of the same VAE. The bottom row of each shows individual generated samples from the cell, and the top row shows their training nearest neighbors. We immediately see that the data-copied region (left, $Z_U^\pi = -8.54$) practically produces blurry replicas of its training nearest neighbors, while the underfit region (right, $Z_U^\pi = +3.3)$ doesn’t appear to produce samples that look like any training image.</p>
<p>Extending these tests to a more complex and practical domain, we check the ImageNet-trained <a href="https://arxiv.org/abs/1809.11096">BigGAN</a> model for data-copying. Being a conditional GAN that can output images of any single ImageNet 12 class, we condition on three separate classes and treat them as three separate models: Coffee, Soap Bubble, and Schooner. Here, it is not so simple to re-train GANs of varying degrees of complexity as we did before with VAEs. Instead, we modulate the model’s ‘truncation threshold’: a level beyond which all inputs are resampled. A larger truncation threshold allows for higher variance latent input, and thus higher variance outputs.</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/GAN_overfitting.png" width="60%" /></p>
<h5 style="text-align: center;"> BigGan, an ImageNet12 conditional GAN, appears to significantly data-copy for all but its highest truncation levels, which are said to trade off between variety and fidelity. </h5>
<p>Low truncation thresholds restrict the model to producing samples near the mode – those it is most confident in. However it appears that in all image classes, this also leads to significant data copying. Not only are the samples less diverse, but they hang closer to the training set than they should. This contrasts with the BigGAN authors’ suggestion that truncation level trades off between ‘variety and fidelity’. It appears that it might trade off between ‘copying and not copying’ the training set.</p>
<p>Again, even the least copying models with maximized truncation (=2) exhibit data-copying in <em>some</em> cells $\pi \in \Pi$:</p>
<p style="text-align: center;"><img src="/assets/2020-08-03-data-copying/GAN_cells.png" width="95%" /></p>
<h5 style="text-align: center;"> Examples from BigGan's data-copied (left) and underfit (right) cells of the 'coffee' (top) and 'soap bubble' (bottom) classes.</h5>
<p>The left two strips show data-copied cells of the coffee and bubble instance spaces (low $Z_U^\pi$), and the right two strips show underfit cells (high $Z_U^\pi$). The bottom row of each strip shows a subset of generated images from that cell, and the top row training images from the cell. To show the diversity of the cell, these are not necessarily the generated samples’ training nearest neighbors as they were in the MNIST example.</p>
<p>We see that the data-copied cells on the left tend to confidently produce samples of one variety that linger too closely to specific examples caught in the training set. In the coffee case, it is the teacup/saucer combination. In the bubble case, it is the single large suspended bubble with a blurred background. Meanwhile, the slightly underfit cells on the right arguably perform better in a ‘generative’ sense. The samples, albeit slightly distorted, are more original. According to the inception space distance metric, they hug less closely to the training set.</p>
<h3 id="data-copying-is-a-real-failure-mode-of-generative-models">Data-copying is a real failure mode of generative models</h3>
<p>The moral of these experiments is that data-copying indeed occurs in contemporary generative models. This failure mode has significant consequences for user privacy and for model generalization. With that said, it is a failure mode not identified by most prominent generative model tests in the literature today.</p>
<ul>
<li>
<p>Data-copying is <em>orthogonal to</em> over-representation; both should be tested when designing and training generative models.</p>
</li>
<li>
<p>Data-copying is straightforward to test efficiently when equipped with a decent distance metric.</p>
</li>
<li>
<p>Having identified this failure mode, it would be interesting to see modeling techniques that actively try to minimize data-copying in training.</p>
</li>
</ul>
<p>So be sure to start probing your models for data-copying, and don’t be afraid to venture off-recipe every once in a while!</p>
<h3 id="more-details">More Details</h3>
<p>Check out <a href="https://arxiv.org/abs/2004.05675">our AISTATS paper on arxiv</a>, and <a href="https://github.com/casey-meehan/data-copying">our data-copying test code on GitHub</a>.</p><a href='mailto:cmeehan@eng.ucsd.edu'>Casey Meehan</a>What does it mean for a generative model to overfit? We formalize the notion of 'data-copying', when a generative model produces only slight variations of the training set and fails to express the diversity of the true distribution. To catch this form of overfitting, we propose a three-sample hypothesis test that is entirely model agnostic. Our experiments indicate that several standard tests condone data-copying, and contemporary generative models like VAEs and GANs can commit data-copying.