Skip to content

Refactored Apriori implementation with correct pruning and candidate … #12697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 78 additions & 56 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Apriori Algorithm is a Association rule mining technique, also known as market basket
Apriori Algorithm is an Association rule mining technique, also known as market basket
analysis, aims to discover interesting relationships or associations among a set of
items in a transactional or relational database.

Expand All @@ -11,7 +11,8 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import defaultdict
from itertools import combinations

Check failure on line 15 in machine_learning/apriori_algorithm.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

machine_learning/apriori_algorithm.py:14:1: I001 Import block is un-sorted or un-formatted


def load_data() -> list[list[str]]:
    """
    Returns a small sample transaction dataset (market-basket style).

    >>> load_data()
    [['milk'], ['milk', 'butter'], ['milk', 'bread'], ['milk', 'bread', 'chips']]
    """
    return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(
    frequent_itemsets: list[list[str]], candidates: list[list[str]]
) -> list[list[str]]:
    """
    Prune candidate itemsets that cannot be frequent.

    By the Apriori property, a k-itemset can only be frequent if every one of
    its (k-1)-subsets was frequent in the previous iteration.  Any candidate
    with an infrequent (k-1)-subset is discarded.

    >>> frequent_itemsets = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
    >>> candidates = [['X', 'Y', 'Z'], ['X', 'Y', 'W']]
    >>> prune(frequent_itemsets, candidates)
    [['X', 'Y', 'Z']]
    """
    # Frozensets give O(1), order-insensitive membership tests (Ruff C401:
    # set comprehension instead of a generator passed to set()).
    previous_frequents = {frozenset(itemset) for itemset in frequent_itemsets}

    pruned_candidates = []
    for candidate in candidates:
        all_subsets_frequent = all(
            frozenset(subset) in previous_frequents
            for subset in combinations(candidate, len(candidate) - 1)
        )
        if all_subsets_frequent:
            pruned_candidates.append(candidate)

    return pruned_candidates


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Returns every frequent itemset in ``data`` together with its support
    count, ordered by itemset size and then lexicographically.

    >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
    >>> apriori(data, 2)  # doctest: +NORMALIZE_WHITESPACE
    [(['A'], 4), (['B'], 3), (['C'], 3),
     (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]

    >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
    >>> apriori(data, 3)
    [(['1'], 4), (['2'], 3), (['3'], 3)]
    """
    # Support counts for single items (k = 1).
    item_counts: defaultdict[str, int] = defaultdict(int)
    for transaction in data:
        for item in transaction:
            item_counts[item] += 1

    current_frequents = [
        [item] for item, count in item_counts.items() if count >= min_support
    ]
    frequent_itemsets = [
        ([item], count)
        for item, count in item_counts.items()
        if count >= min_support
    ]

    k = 2
    while current_frequents:
        # Join step: union pairs of frequent (k-1)-itemsets whose union has
        # exactly k items (Ruff C414: no redundant list() inside sorted()).
        joined = [
            sorted(set(first) | set(second))
            for first in current_frequents
            for second in current_frequents
            if len(set(first) | set(second)) == k
        ]
        # Each candidate is a sorted list, so sorted tuples deduplicate them.
        candidates = [list(candidate) for candidate in {tuple(c) for c in joined}]

        # Prune step: drop candidates that have an infrequent (k-1)-subset.
        candidates = prune(current_frequents, candidates)

        # Count the support of the surviving candidates; canonical sorted
        # tuples keep the counting keys deterministic.
        candidate_counts: defaultdict[tuple[str, ...], int] = defaultdict(int)
        for transaction in data:
            transaction_set = set(transaction)
            for candidate in candidates:
                if set(candidate).issubset(transaction_set):
                    candidate_counts[tuple(sorted(candidate))] += 1

        current_frequents = [
            list(key) for key, count in candidate_counts.items() if count >= min_support
        ]
        frequent_itemsets.extend(
            (list(key), count)
            for key, count in candidate_counts.items()
            if count >= min_support
        )

        k += 1

    # Deterministic output order: by itemset length, then lexicographically.
    return sorted(frequent_itemsets, key=lambda entry: (len(entry[0]), entry[0]))


if __name__ == "__main__":
    # Run the embedded doctests first, then demonstrate the algorithm on the
    # sample transactions with a user-defined minimum support threshold.
    import doctest

    doctest.testmod()

    transactions = load_data()
    min_support_threshold = 2

    frequent_itemsets = apriori(transactions, min_support=min_support_threshold)

    print("Frequent Itemsets:")
    print("\n".join(f"{itemset}: {support}" for itemset, support in frequent_itemsets))
Loading