Federated Learning Simulator

LAB17: Federated Learning with Flower

Visualize Distributed Model Training

See how federated learning aggregates knowledge from multiple clients while keeping data local.

Concept from LAB17

See Section 17.2: The FedAvg Algorithm in the PDF book.

Interactive FL Simulator

The simulator plots two views: convergence over rounds and a client participation matrix. The code below reproduces both.

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# Simulate FL training
num_rounds = 10
num_clients = 5

# Simulated accuracy trajectories
def simulate_fl_training(num_rounds, num_clients, iid=True):
    """Toy convergence model: averages client accuracies as a stand-in
    for FedAvg's weight averaging."""
    global_accuracy = [0.1]  # start at random-guess accuracy

    for round_num in range(num_rounds):
        # Each client trains locally
        client_accuracies = []
        for c in range(num_clients):
            if iid:
                # IID data: similar improvement
                improvement = 0.08 + np.random.normal(0, 0.01)
            else:
                # Non-IID: more variance
                improvement = 0.06 + np.random.normal(0, 0.03)

            client_acc = min(0.99, global_accuracy[-1] + improvement)
            client_accuracies.append(client_acc)

        # FedAvg aggregation
        new_global = np.mean(client_accuracies)
        global_accuracy.append(new_global)

    return global_accuracy

# Simulate IID vs Non-IID
iid_accuracy = simulate_fl_training(num_rounds, num_clients, iid=True)
non_iid_accuracy = simulate_fl_training(num_rounds, num_clients, iid=False)

# Also simulate centralized training for comparison
centralized_accuracy = [0.1]
for _ in range(num_rounds):
    improvement = 0.09 + np.random.normal(0, 0.005)
    centralized_accuracy.append(min(0.99, centralized_accuracy[-1] + improvement))

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Convergence comparison
ax = axes[0]
rounds = range(num_rounds + 1)
ax.plot(rounds, centralized_accuracy, 'g-o', label='Centralized', linewidth=2)
ax.plot(rounds, iid_accuracy, 'b-s', label='FL (IID data)', linewidth=2)
ax.plot(rounds, non_iid_accuracy, 'r-^', label='FL (Non-IID data)', linewidth=2)
ax.set_xlabel('Round')
ax.set_ylabel('Accuracy')
ax.set_title('Federated vs Centralized Training')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Client participation
ax = axes[1]
client_data = np.random.rand(num_rounds, num_clients) > 0.2  # True ~80% of the time
ax.imshow(client_data.T, cmap='Greens', aspect='auto')
ax.set_xlabel('Round')
ax.set_ylabel('Client ID')
ax.set_title('Client Participation per Round')
ax.set_yticks(range(num_clients))
ax.set_yticklabels([f'Client {i+1}' for i in range(num_clients)])

plt.tight_layout()
plt.show()
Figure 26.1: Federated learning convergence across rounds

FedAvg Algorithm

The Federated Averaging algorithm:

\[w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} w_k^{t+1}\]

where:

- \(K\) = number of clients
- \(n_k\) = number of samples at client \(k\)
- \(n = \sum_k n_k\) = total number of samples
- \(w_k^{t+1}\) = client \(k\)'s model after local training in round \(t\)
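In code, the aggregation step is just this sample-weighted average of the client parameter vectors. Below is a minimal NumPy sketch; `fedavg_aggregate` and the toy weight vectors are illustrative, not part of the lab's simulator.

import numpy as np

def fedavg_aggregate(client_weights, client_num_samples):
    """FedAvg: sample-weighted average of client parameter vectors."""
    total = sum(client_num_samples)
    stacked = np.stack(client_weights)           # shape: (K, num_params)
    coef = np.array(client_num_samples) / total  # n_k / n for each client
    return coef @ stacked                        # sum_k (n_k / n) * w_k

# Illustrative example: 3 clients, 4-parameter model
w = [np.array([1.0, 2.0, 3.0, 4.0]),
     np.array([1.2, 1.8, 3.1, 3.9]),
     np.array([0.8, 2.2, 2.9, 4.1])]
n = [10000, 5000, 3000]
print(fedavg_aggregate(w, n))  # dominated by the 10K-sample client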

The code below visualizes this weighted aggregation:
fig, ax = plt.subplots(figsize=(10, 6))

# Visualize weighted aggregation
clients = ['Client 1\n(10K samples)', 'Client 2\n(5K samples)', 'Client 3\n(3K samples)']
samples = [10000, 5000, 3000]
weights = np.array(samples) / sum(samples)
colors = plt.cm.Blues(np.linspace(0.4, 0.8, len(clients)))

# Draw clients
for i, (client, w, color) in enumerate(zip(clients, weights, colors)):
    ax.add_patch(plt.Rectangle((i*2, 0), 1.5, 3, color=color, alpha=0.7))
    ax.text(i*2 + 0.75, 1.5, client, ha='center', va='center', fontsize=10)
    ax.text(i*2 + 0.75, -0.5, f'Weight: {w:.2f}', ha='center', fontsize=9)

# Draw aggregation arrow
ax.annotate('', xy=(3, 5), xytext=(3, 3.5),
           arrowprops=dict(arrowstyle='->', lw=2, color='green'))
ax.text(3, 4.25, 'Weighted\nAverage', ha='center', fontsize=10)

# Draw global model
ax.add_patch(plt.Rectangle((2, 5.5), 2, 1.5, color='green', alpha=0.7))
ax.text(3, 6.25, 'Global Model', ha='center', va='center', fontsize=11, fontweight='bold')

ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-1, 8)
ax.axis('off')
ax.set_title('FedAvg: Weighted Model Aggregation', fontsize=12)
plt.show()
Figure 26.2: FedAvg weighted aggregation

Non-IID Data Challenge

In practice client data is rarely identically distributed: each device tends to hold only a few classes (one user's photos, one hospital's patient mix), so local updates pull the global model in different directions and convergence slows.

The code below contrasts an IID class distribution with a non-IID one, where each client holds only a few classes:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# IID distribution
ax = axes[0]
for c in range(3):
    data = np.random.randint(0, 10, 100)  # uniform draw over all 10 classes
    ax.hist(data, bins=10, range=(0, 10), alpha=0.5, label=f'Client {c+1}')
ax.set_xlabel('Class Label')
ax.set_ylabel('Count')
ax.set_title('IID Data Distribution')
ax.legend()

# Non-IID distribution (each client has different classes)
ax = axes[1]
for c in range(3):
    # Each client holds only 3 of the 10 classes: {0,1,2}, {3,4,5}, {6,7,8}
    main_classes = list(range(c * 3, c * 3 + 3))
    data = np.random.choice(main_classes, 100)
    ax.hist(data, bins=10, range=(0, 10), alpha=0.5, label=f'Client {c+1}')
ax.set_xlabel('Class Label')
ax.set_ylabel('Count')
ax.set_title('Non-IID Data Distribution')
ax.legend()

plt.tight_layout()
plt.show()
Figure 26.3: IID vs Non-IID data distribution
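The histograms above hard-code the class split. A common way to generate more realistic non-IID partitions in the FL literature is to draw each client's class proportions from a Dirichlet distribution; smaller concentration values give more skewed clients. A minimal sketch, assuming integer class labels 0-9 (the fake label array is illustrative):

import numpy as np

rng = np.random.default_rng(42)
num_classes, num_clients, alpha = 10, 5, 0.5  # smaller alpha = more skew
labels = rng.integers(0, num_classes, 10000)  # fake dataset labels

# For each class, split its sample indices across clients
client_indices = [[] for _ in range(num_clients)]
for cls in range(num_classes):
    idx = np.where(labels == cls)[0]
    rng.shuffle(idx)
    proportions = rng.dirichlet(alpha * np.ones(num_clients))
    cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
    for client, part in enumerate(np.split(idx, cuts)):
        client_indices[client].extend(part.tolist())

for client, idx in enumerate(client_indices):
    print(f'Client {client + 1}: {len(idx)} samples')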

Key Insights

Aspect           IID Data               Non-IID Data
---------------  ---------------------  --------------------------------
Convergence      Fast, stable           Slower, may oscillate
Final Accuracy   ~Same as centralized   May be lower
Solution         Standard FedAvg        FedProx, FedNova (sketch below)
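The table points to FedProx as one remedy: each client minimizes \(F_k(w) + \frac{\mu}{2}\lVert w - w_t\rVert^2\), where the proximal term keeps local updates close to the current global model \(w_t\). Below is a minimal gradient-descent sketch on a least-squares client objective; the data, \(\mu\), and learning rate are illustrative assumptions, not from the lab:

import numpy as np

def fedprox_local_update(w_global, X, y, mu=0.1, lr=0.01, steps=50):
    """One client's local training with FedProx's proximal term."""
    w = w_global.copy()
    for _ in range(steps):
        grad_loss = 2 * X.T @ (X @ w - y) / len(y)  # gradient of MSE F_k(w)
        grad_prox = mu * (w - w_global)             # gradient of (mu/2)||w - w_t||^2
        w -= lr * (grad_loss + grad_prox)
    return w

# Illustrative client data: y = X @ [1, -2, 0.5] + noise
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)
print(fedprox_local_update(np.zeros(3), X, y))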

Try It Yourself

# Run FL with Flower: server and clients live in separate processes
import flwr as fl
import numpy as np

# Minimal client: subclass NumPyClient and fill in real training logic
class MyClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [np.zeros(10)]     # your model's weights
    def fit(self, parameters, config):
        return parameters, 1, {}  # updated weights, sample count, metrics
    def evaluate(self, parameters, config):
        return 0.0, 1, {}         # loss, sample count, metrics

# Terminal 1: start the server
fl.server.start_server(
    server_address="localhost:8080",
    config=fl.server.ServerConfig(num_rounds=5),
)

# Terminals 2+: start the clients
fl.client.start_numpy_client(
    server_address="localhost:8080",
    client=MyClient(),
)