Inductive bias illustration

import numpy as np
import matplotlib.pyplot as plt
def true_function(x):
    """Ground-truth labelling rule: 1 strictly inside (-1, 1), 0 elsewhere.

    The result has the same shape and dtype as ``x``. The endpoints -1 and 1
    are labelled 0 because both comparisons are strict.
    """
    inside = np.logical_and(x > -1, x < 1)
    return np.where(inside, np.ones_like(x), np.zeros_like(x))

# Draw 16 training points uniformly between -2 and 2, label them with the
# ground-truth rule, and record which positions carry each label.
samples = np.random.uniform(low=-2, high=2, size=16)
labels = true_function(samples)
true_inds = np.flatnonzero(labels == 1)
false_inds = np.flatnonzero(labels == 0)

def build_heuristic_function(samples, labels):
    """Build a hypothesis that fits the positive samples but nothing else.

    The returned function evaluates ``1 - prod_i |x - s_i| ** (1/4)`` where
    ``s_i`` ranges over the samples whose label equals 1. It is exactly 1 at
    every positive sample (one factor of the product is zero there) and falls
    away from 1 in between, so it matches the training positives without
    capturing the underlying interval.

    Parameters
    ----------
    samples : numpy.ndarray
        1-D array of sample locations.
    labels : numpy.ndarray
        Array aligned with ``samples``; entries equal to 1 mark positives.

    Returns
    -------
    callable
        ``heuristic_func(x)`` accepting a scalar or an array and returning a
        value of the same shape as ``x``.
    """
    # Boolean indexing copies, so later changes to `samples` do not leak in.
    true_samples = samples[labels == 1]

    def heuristic_func(x):
        """Evaluate the product-based heuristic at ``x`` (scalar or array)."""
        x = np.asarray(x)
        # Broadcast every query point against every positive sample; the
        # trailing axis indexes the positive samples.
        distances = np.abs(x[..., np.newaxis] - true_samples)
        # With no positive samples the product over the empty axis is 1, so
        # the result is zeros shaped like `x` rather than a bare scalar that
        # cannot be plotted against `x`.
        return 1 - np.prod(distances ** (1 / 4), axis=-1)

    return heuristic_func
# Hypothesis built from the positively labelled training samples.
heuristic_func = build_heuristic_function(samples=samples, labels=labels)

def build_good_heuristic_function(samples, labels):
    """Build an interval hypothesis spanning the positive samples.

    The returned function predicts 1 for every ``x`` between the smallest and
    the largest positively labelled sample (inclusive) and 0 outside of that
    range.

    Parameters
    ----------
    samples : numpy.ndarray
        1-D array of sample locations.
    labels : numpy.ndarray
        Array aligned with ``samples``; entries equal to 1 mark positives.

    Returns
    -------
    callable
        ``heuristic_func(x)`` taking an array and returning an array of the
        same shape and dtype filled with 0s and 1s.
    """
    good_samples = samples[labels == 1]
    if good_samples.size:
        min_good = np.min(good_samples)
        max_good = np.max(good_samples)
    else:
        # No positive evidence: use an empty interval so the hypothesis
        # predicts 0 everywhere instead of np.min/np.max raising on an
        # empty array.
        min_good, max_good = np.inf, -np.inf

    def heuristic_func(x):
        """Return 1 inside [min_good, max_good] and 0 outside."""
        rval = np.ones_like(x)
        rval[x < min_good] = 0
        rval[x > max_good] = 0
        return rval

    return heuristic_func
# Interval hypothesis spanning the positively labelled training samples.
good_heuristic_func = build_good_heuristic_function(samples=samples, labels=labels)

# Plotting grid: 2048 evenly spaced points on [-2, 2] merged with the sample
# locations themselves and sorted, so each curve is evaluated exactly at
# every training point as well.
domain_arr = np.sort(np.concatenate((np.linspace(-2, 2, num=2048), samples)))
def _plot_samples(filename, curve_func=None, *, show_true_function=False):
    """Draw the labelled samples, optionally overlay a curve, save and show.

    Reads the module-level ``samples``, ``labels``, ``true_inds``,
    ``false_inds`` and ``domain_arr``.

    Parameters
    ----------
    filename : str
        Path the figure is written to (saved at 300 dpi).
    curve_func : callable, optional
        If given, ``curve_func(domain_arr)`` is drawn as a dashed blue line.
    show_true_function : bool, keyword-only
        If true, ``true_function(domain_arr)`` is drawn as a solid black line.
        Off by default, matching the three figures produced below.

    Returns
    -------
    (fig, ax)
        The matplotlib figure and axes that were created.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Labels are attached to the artists, but no legend is drawn.
    ax.scatter(samples[true_inds], labels[true_inds], c='g', label='True')
    ax.scatter(samples[false_inds], labels[false_inds], c='r', label='False')
    if show_true_function:
        ax.plot(domain_arr, true_function(domain_arr), c='k', label='True Function')
    if curve_func is not None:
        ax.plot(domain_arr, curve_func(domain_arr), c='b', label='Heuristic Function', alpha=0.75, ls='--')

    # Dashed guides at the two possible label values.
    ax.axhline(0, color='k', ls='--', alpha=0.5)
    ax.axhline(1, color='k', ls='--', alpha=0.5)
    # Shade x in [-1, 1] green and the two outer bands red, between y=0 and y=1.
    ax.fill_betweenx([0, 1], -1, 1, color='g', alpha=0.1)
    ax.fill_betweenx([0, 1], -2, -1, color='r', alpha=0.1)
    ax.fill_betweenx([0, 1], 2, 1, color='r', alpha=0.1)

    ax.set(xlabel='x', ylabel='y', ylim=[-0.1, 1.1], xlim=[-2.1, 2.1])
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()
    return fig, ax


# Figure 1: the labelled samples only, with no hypothesis drawn.
fig, ax = _plot_samples('heuristic_function_1.png')

# Figure 2: the interval hypothesis spanning the positive samples.
fig, ax = _plot_samples('heuristic_function_2.png', good_heuristic_func)

# Figure 3: the product-based hypothesis that only fits the positive samples.
fig, ax = _plot_samples('heuristic_function_3.png', heuristic_func)