
Flatten a List in Python: Effortlessly Simplify Nested Structures


How to Flatten a List of Lists in Python

Sometimes, when you’re working with data, you may have the data as a list of nested lists. A common operation is to flatten this data into a one-dimensional list in Python. Flattening a list involves converting a multidimensional list, such as a matrix, into a one-dimensional list.

To better illustrate what it means to flatten a list, say that you have the following matrix of numeric values:

Python

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

Flattening this matrix would give you the following one-dimensional list:

Python

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

How do you manage to flatten your matrix and get a one-dimensional list like the one above? In this tutorial, you’ll learn how to do that in Python.

How to Flatten a List of Lists With a for Loop

How can you flatten a list of lists in Python? In general, to flatten a list of lists, you can run the following steps either explicitly or implicitly:

  1. Create a new empty list to store the flattened data.
  2. Iterate over each nested list or sublist in the original list.
  3. Add every item from the current sublist to the list of flattened data.
  4. Return the resulting list with the flattened data.

To continue with the matrix example, here’s how you would translate these steps into Python code using a for loop and the .extend() method:

Python

def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list

Inside flatten_extend(), you first create a new empty list called flat_list. You'll use this list to store the flattened data as you extract it from matrix. Then you start a loop over the inner, or nested, lists in matrix. In this example, you use the name row to represent the current nested list. In each iteration, you call .extend() to add all the items of row to flat_list in one step. Finally, you return flat_list.

Now go ahead and run the following code to check that your function does the job:

Python

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
result = flatten_extend(matrix)
print(result)

The output should be:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

Congratulations! You have successfully flattened a list of lists using a for loop and the .extend() method.

Using a Comprehension to Flatten a List of Lists

Another way to flatten a list of lists is by using a comprehension. Comprehensions provide a concise and expressive way to create new lists based on existing ones.

To flatten a list of lists with a comprehension, you can use a single list comprehension with two for clauses. Here's how you can do it:

Python

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
flat_list = [item for sublist in matrix for item in sublist]
print(flat_list)

The output will be the same as before:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

In this example, the comprehension [item for sublist in matrix for item in sublist] builds flat_list by iterating over each sublist in matrix and then over each item in that sublist. The items are added to flat_list one by one.
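The two for clauses read in the same order as the equivalent nested for loops, with the outer loop first. The following sketch spells that out explicitly; the function name flatten_nested_loops is hypothetical and only used for illustration:

```python
def flatten_nested_loops(matrix):
    flat_list = []
    # Outer loop: corresponds to "for sublist in matrix"
    for sublist in matrix:
        # Inner loop: corresponds to "for item in sublist"
        for item in sublist:
            flat_list.append(item)
    return flat_list


matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

# The explicit loops and the comprehension build the same list
assert flatten_nested_loops(matrix) == [item for sublist in matrix for item in sublist]
print(flatten_nested_loops(matrix))
```

If you ever get confused about the clause order, writing the loops out like this and then collapsing them into a comprehension is a reliable way to get it right.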

A comprehension can be more concise and readable than an explicit for loop, especially in simple cases like this one. However, it adds items one at a time, so it can be slower than approaches that copy whole sublists at once, such as .extend() or itertools.chain(). Comprehensions with many for clauses or conditions can also become hard to read.
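Also keep in mind that this comprehension, like the other list-based approaches in this tutorial, flattens exactly one level of nesting. A quick check with a deeper, irregular list shows what happens; the list deep is a made-up example:

```python
# A made-up list whose first sublist contains another list
deep = [[1, [2, 3]], [4, 5]]

# Only the outer level is unpacked; the inner [2, 3] stays a list
flat_once = [item for sublist in deep for item in sublist]
print(flat_once)  # [1, [2, 3], 4, 5]
```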

Flattening a List Using Standard-Library and Built-in Tools

Python provides several built-in functions and modules in its standard library that can help you flatten a list of lists without writing custom code. Let’s explore some of these tools.

Chaining Iterables With itertools.chain()

The itertools.chain() function allows you to merge multiple iterables into a single iterable. You can use it to flatten a list of lists by passing each sublist as an argument to chain().

Here’s an example:

Python

import itertools
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
flat_list = list(itertools.chain(*matrix))
print(flat_list)

The output will be the same as before:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

In this example, the * operator unpacks matrix so that each sublist becomes a separate argument to chain(). chain() then returns an iterator that yields the items of each sublist in turn, and list() consumes that iterator to build the flattened list.
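If you'd rather not unpack matrix with *, itertools also provides chain.from_iterable(), which takes a single iterable of iterables. This can be handy when the outer iterable is large or lazy, such as a generator, because nothing has to be unpacked into an argument tuple up front:

```python
import itertools

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

# from_iterable() pulls the sublists from matrix one at a time
# instead of receiving them all as separate arguments
flat_list = list(itertools.chain.from_iterable(matrix))
print(flat_list)
```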

Concatenating Lists With functools.reduce()

The functools.reduce() function applies a two-argument function cumulatively: first to the first two items of an iterable, then to that result and the next item, and so on, until it has processed every item. To concatenate the sublists of a list, you can pass it a function that combines two lists with the + operator.

Here’s an example:

Python

import functools
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
flat_list = functools.reduce(lambda a, b: a + b, matrix)
print(flat_list)

The output will be the same as before:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

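In this example, the lambda function lambda a, b: a + b concatenates two lists with the + operator. reduce() applies it cumulatively to the sublists of matrix, producing the flattened list. Note that each + creates a new list, so the items gathered so far are copied again at every step, which makes this approach increasingly slow as the number of sublists grows.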
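A common variant passes operator.iconcat, the function form of the in-place += operator, together with an empty list as reduce()'s initial value. reduce() then extends a single accumulator list instead of building a new one at each step. This is a sketch of that alternative, not the approach used above:

```python
import functools
import operator

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

# operator.iconcat(a, b) performs a += b and returns a, so the same
# accumulator list is extended in place on every step. Starting from a
# fresh [] keeps the sublists of matrix themselves unmodified.
flat_list = functools.reduce(operator.iconcat, matrix, [])
print(flat_list)

# With an initial value, an empty input returns [] instead of raising TypeError
print(functools.reduce(operator.iconcat, [], []))  # []
```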

Using sum() to Concatenate Lists

The built-in sum() function can also concatenate lists. By default, sum() starts from 0, so to add up lists you pass an empty list as its second argument, start. sum() then begins with that empty list and concatenates each sublist to the accumulated result.

Here’s an example:

Python

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
flat_list = sum(matrix, [])
print(flat_list)

The output will be the same as before:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

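In this example, sum(matrix, []) starts with an empty list [] and concatenates each sublist of matrix to it. Because every addition creates a new list, this approach copies the accumulated items repeatedly and scales poorly to long lists of lists.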
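To see why the [] start value is required, try calling sum() without it. The default start of 0 can't be added to a list, so Python raises a TypeError. This short sketch uses a two-row slice of the matrix:

```python
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
]

try:
    # start defaults to 0, so this evaluates 0 + [9, 3, 8, 3]
    sum(matrix)
except TypeError as error:
    print(f"TypeError: {error}")

# Passing an empty list as start makes list concatenation work
print(sum(matrix, []))
```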

Considering Performance While Flattening Your Lists

When working with large datasets, the performance of your list-flattening code becomes important. Depending on the size of your input data and the complexity of the flattening process, some approaches may be more efficient than others.

To compare the performance of different methods, you can use the timeit module, which provides a way to measure the execution time of small bits of Python code.

Here’s an example that compares the execution times of the previously discussed methods:

Python

import itertools
import functools
import timeit
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list
def flatten_comprehension(matrix):
    return [item for sublist in matrix for item in sublist]
def flatten_chain(matrix):
    return list(itertools.chain(*matrix))
def flatten_reduce(matrix):
    return functools.reduce(lambda a, b: a + b, matrix)
def flatten_sum(matrix):
    return sum(matrix, [])
flatten_times = {
    "flatten_extend": timeit.timeit(lambda: flatten_extend(matrix), number=1000),
    "flatten_comprehension": timeit.timeit(lambda: flatten_comprehension(matrix), number=1000),
    "flatten_chain": timeit.timeit(lambda: flatten_chain(matrix), number=1000),
    "flatten_reduce": timeit.timeit(lambda: flatten_reduce(matrix), number=1000),
    "flatten_sum": timeit.timeit(lambda: flatten_sum(matrix), number=1000),
}
for method, time in flatten_times.items():
    print(f"{method}: {time:.6f} seconds")

This code uses timeit.timeit() to measure how long each flattening function takes. It calls each function 1,000 times and prints the total time for those 1,000 calls, in seconds.

When you run this code, the exact times will vary depending on your system, your Python version, and the size and shape of your data. Still, comparing the times gives you an idea of each method's relative performance on this input.
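Because the sample matrix is only 4 by 4, these timings say little about how each approach scales. The following sketch reruns the comparison on a larger, hypothetical input of 500 rows with 100 items each; those dimensions and the number=10 repetition count are arbitrary choices you can adjust. It's self-contained, so it redefines the functions it times:

```python
import functools
import itertools
import timeit

# Hypothetical larger input: 500 rows of 100 integers each
big_matrix = [list(range(100)) for _ in range(500)]

def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list

def flatten_comprehension(matrix):
    return [item for sublist in matrix for item in sublist]

def flatten_chain(matrix):
    return list(itertools.chain(*matrix))

def flatten_reduce(matrix):
    return functools.reduce(lambda a, b: a + b, matrix)

def flatten_sum(matrix):
    return sum(matrix, [])

functions = [
    flatten_extend,
    flatten_comprehension,
    flatten_chain,
    flatten_reduce,
    flatten_sum,
]

# Sanity check: all approaches must agree before you compare their speed
expected = flatten_extend(big_matrix)
for func in functions:
    assert func(big_matrix) == expected

for func in functions:
    # Total time for 10 calls. timeit() runs the lambda immediately,
    # so capturing the loop variable func is safe here.
    seconds = timeit.timeit(lambda: func(big_matrix), number=10)
    print(f"{func.__name__}: {seconds:.4f} seconds")
```

At this size, you can typically expect flatten_reduce() and flatten_sum() to fall well behind the others, since they copy the accumulated list on every step, but measure on your own machine rather than relying on that.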

Flattening Python Lists for Data Science With NumPy

If you’re working with data science applications and need to flatten multidimensional arrays, the NumPy library provides a powerful and efficient way to do so.

NumPy is a popular library in the Python data science ecosystem. It introduces the ndarray object, an efficient multidimensional container for homogeneous data, along with many functions and methods for array manipulation, including flattening arrays.

Here’s an example of how to flatten a multidimensional array using NumPy:

Python

import numpy as np
matrix = np.array([
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
])
flat_array = matrix.flatten()
print(flat_array)

The output contains the same values, shown in NumPy's array format without commas:

[9 3 8 3 4 5 2 8 6 4 3 1 1 0 4 5]

In this example, you convert the matrix list to a NumPy ndarray using the np.array() function. Then you call .flatten() on the ndarray to get a flattened NumPy array.
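NumPy offers a couple of related tools worth knowing. .flatten() always returns a new copy of the data, while .ravel() returns a view of the original array whenever the memory layout allows it, which avoids copying. If you need a regular Python list at the end, .tolist() converts the result back. Here's a short sketch:

```python
import numpy as np

matrix = np.array([
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
])

flat_copy = matrix.flatten()  # always an independent copy
flat_view = matrix.ravel()    # a view here, because matrix is contiguous

# Writing through the view changes the original array; the copy is unaffected
flat_view[0] = 99
print(matrix[0, 0])  # 99
print(flat_copy[0])  # 9

# Convert the untouched copy back to a plain Python list
print(flat_copy.tolist())
```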

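Flattening an existing ndarray is fast because NumPy stores the data in a contiguous block of memory and does the work in compiled code. Keep in mind, though, that building the array from a Python list with np.array() has its own cost, so NumPy pays off most when your data already lives in arrays or when you plan to do further numerical work on it.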

Conclusion

Flattening a list of lists is a common operation in Python, especially when working with nested data structures. In this tutorial, you've learned several methods to flatten lists, including a for loop, a comprehension, and tools from Python's built-ins and standard library.

You’ve also explored the performance considerations when flattening lists and how NumPy can be used for efficient flattening in data science applications.

By mastering the techniques covered in this tutorial, you’ll be able to easily flatten lists and analyze nested data structures in Python with confidence.