Neural correlates of hidden mental models

This post describes an idea about interpreting artificial neural networks that I used in my computer science master's thesis. I'm genuinely curious whether this idea has any value, so please send me feedback by email, in particular on these questions:

  1. Does this even make any sense?
  2. Does this have a name already?
  3. Where did I make mistakes?
  4. Has this ever been used productively?

Context

Neural networks tend to be more accurate on classification problems than hand-crafted, explicit algorithms. Not limited by the preconceptions of the engineer, they can adapt to a given problem, pick up information that the engineer didn't consider relevant, and synthesize their own mental models.

The field of neural network interpretability arose from the desire to bridge the growing gap between the accuracy of neural networks and our understanding of how they achieve it.

Finding the neural correlates

For any task a neural network performs, some combination of neurons must be responsible for it. In biological brains, such a combination is called a "neural correlate." I wondered: is it possible to find neural correlates in artificial neural networks as well? And in the context of interpretability, what are the neural correlates of the unknown mental models that give a neural network the edge over a hand-crafted algorithm?

I used the following approach in my 2018 computer science master's thesis:

Given a system A, which is based on a neural network and is more accurate than system B (an arbitrary black box), let both systems classify a large number of samples. Then correlate the output value of each neuron in system A with a per-sample number indicating whether system A was more accurate than system B on that sample. The neurons with the highest correlation coefficients are candidates for having encoded a property of the classification problem that system B does not take into account.

Algorithmically:
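A minimal Python/NumPy sketch of the procedure described above, not the exact code from the thesis. It assumes that the per-sample "number" is a 0/1 indicator that is 1 when system A classified the sample correctly and system B did not, and that the output of every neuron in system A has already been recorded for every sample. The function name `rank_neurons_by_advantage` is hypothetical. Pearson correlation against a binary variable like this is also known as the point-biserial correlation.

```python
import numpy as np


def rank_neurons_by_advantage(activations, a_correct, b_correct):
    """Correlate each neuron's output with an 'A beat B' indicator.

    activations: array of shape (n_samples, n_neurons), the recorded output
        of every neuron in system A for every sample.
    a_correct, b_correct: boolean arrays of shape (n_samples,), whether each
        system classified the sample correctly.
    Returns (order, r): neuron indices sorted by descending correlation, and
    the Pearson correlation coefficient of every neuron.
    """
    acts = np.asarray(activations, dtype=float)
    # Assumed indicator: 1.0 where A was right and B was wrong, else 0.0.
    advantage = (np.asarray(a_correct, dtype=bool)
                 & ~np.asarray(b_correct, dtype=bool)).astype(float)

    # Pearson correlation of every column with the indicator, vectorized.
    acts_c = acts - acts.mean(axis=0)
    adv_c = advantage - advantage.mean()
    numer = acts_c.T @ adv_c
    denom = np.sqrt((acts_c ** 2).sum(axis=0) * (adv_c ** 2).sum())
    with np.errstate(divide="ignore", invalid="ignore"):
        # Constant neurons (zero variance) get 0 instead of NaN.
        r = np.where(denom > 0, numer / denom, 0.0)

    order = np.argsort(-r)
    return order, r
```

The neurons at the front of `order` are the candidates to inspect further. A neuron that never changes its output has zero variance, so the sketch assigns it a coefficient of 0 rather than NaN.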

Reconstructing the model

Having identified candidates for neurons that encode new knowledge about the classification problem, we can narrow down our focus of analysis.

If our object of inspection is based on an image recognition neural network, we could analyze these neurons through feature visualization to get an idea of what these highly correlated neurons "see". For example, here are some visualizations made with the software Lucid, showing various layers of a NN, from low-level layers that detect edges and patterns to high-level layers that detect complex objects:

[Figure: example feature visualizations made with Lucid]

I should add that while I applied this idea in my thesis, I unfortunately didn't learn much from visualizing the highly correlated neurons. This was likely because I was crazy enough to apply it to DeepVariant, which technically uses an image recognition network but is trained on synthetic 6-channel pile-up images rather than natural images. As a result, my visualizations were mostly indecipherable noise. I still wonder whether this idea would be more fruitful when applied to a NN trained on natural images.
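Feature visualization of the kind Lucid produces is built on activation maximization: start from a random input and repeatedly nudge it in the direction that increases the chosen neuron's output. The toy NumPy sketch below shows only that core loop, on a hypothetical single hidden layer h = tanh(W1·x + b1) with the gradient written out by hand. It does not use Lucid's API; real visualizations of deep image networks rely on automatic differentiation plus image-specific regularizers and parameterizations. The function name and hyperparameters are illustrative assumptions.

```python
import numpy as np


def visualize_neuron(W1, b1, neuron, steps=200, lr=0.1, l2=0.01, seed=0):
    """Activation maximization for one unit of a toy layer h = tanh(W1 @ x + b1).

    Performs gradient ascent on the input x to maximize
    tanh(W1[neuron] @ x + b1[neuron]) - (l2 / 2) * ||x||^2.
    The L2 penalty keeps the input from growing without bound.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=0.01, size=W1.shape[1])  # small random starting input
    for _ in range(steps):
        pre = W1[neuron] @ x + b1[neuron]
        # Gradient of tanh(pre) w.r.t. x is (1 - tanh(pre)^2) * W1[neuron];
        # the penalty term contributes -l2 * x.
        grad = (1.0 - np.tanh(pre) ** 2) * W1[neuron] - l2 * x
        x += lr * grad
    return x
```

The returned `x` is the input pattern the unit "prefers". For an image network, `x` would be an image, and that image is what a feature visualization displays.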

Example

For a better intuition, let's consider a real-world example. In 2020, we already have neural networks that detect medical issues in X-ray images where radiologists fail to detect them. If we showed 10,000 images to both the NN and a radiologist, we might find that there are 120 images where the NN correctly identified the disease but the radiologist didn't. Furthermore, we might find that there are 3 neurons in the NN whose outputs spike almost exclusively when viewing those 120 images.

Those 3 neurons might pick up something that the radiologist can't. Correlation doesn't imply causation, but it's still worth taking a look. By applying a feature visualization method to those neurons, we might find, for example, that they are sensitive to particular ring-shaped patterns that the human eye has difficulty picking up. With this knowledge, we could devise an image filter that enhances these particular patterns to help the radiologist recognize them as well.
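The phrase "spike almost exclusively" in the X-ray example can be made concrete in several ways. One simple, assumed definition: call an output a spike when it exceeds a high per-neuron quantile (the 99th percentile by default here, an arbitrary choice), then measure what share of each neuron's spikes fall on the images where the NN was right and the radiologist was wrong. A value near 1.0 matches the situation described in the example. The function name and threshold are hypothetical.

```python
import numpy as np


def spike_selectivity(activations, advantage_mask, quantile=0.99):
    """Share of each neuron's spikes that land on 'advantage' samples.

    activations: array of shape (n_samples, n_neurons).
    advantage_mask: boolean array of shape (n_samples,), True where the NN
        was correct and the baseline (e.g. the radiologist) was not.
    A spike is an output strictly above the neuron's own `quantile` value.
    Neurons that never spike get a selectivity of 0.
    """
    acts = np.asarray(activations, dtype=float)
    mask = np.asarray(advantage_mask, dtype=bool)
    thresholds = np.quantile(acts, quantile, axis=0)  # one threshold per neuron
    spikes = acts > thresholds                        # (n_samples, n_neurons)
    n_spikes = spikes.sum(axis=0)
    on_target = (spikes & mask[:, None]).sum(axis=0)
    return np.divide(on_target, n_spikes,
                     out=np.zeros(acts.shape[1]), where=n_spikes > 0)
```

Unlike a single correlation coefficient, this measure ignores how a neuron behaves below its threshold, so it answers the narrower question of where a neuron's strongest responses occur.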