Questions about the Speed of DGL

I ran some tests and found that certain basic DGL operators are even slower than the vanilla PyTorch implementation. Is this expected? Is there a general rule for when using DGL can improve the performance of your model?

Here are the code and the results. The tests were run on a laptop 3060 GPU under Windows.

import torch
import dgl

torch.__version__  #  '2.1.1'
dgl.__version__  # '2.0.0+cu121'


# pytorch functions

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the rows of `data` that share a segment id.

    PyTorch stand-in for TensorFlow's `unsorted_segment_sum`.

    Args:
        data: 2-D tensor; row ``i`` is added to the output row
            ``segment_ids[i]``.
        segment_ids: 1-D index tensor with one entry per row of `data`.
        num_segments: number of rows in the output.

    Returns:
        Tensor of shape ``(num_segments, data.size(1))`` with the dtype and
        device of `data`; rows that receive no input stay zero.
    """
    n_cols = data.size(1)
    # scatter_add_ needs an index with the same shape as the source, so the
    # per-row id is repeated across every column.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    out = data.new_zeros((num_segments, n_cols))
    out.scatter_add_(0, index, data)
    return out


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Average the rows of `data` that share a segment id.

    Args:
        data: 2-D tensor; row ``i`` contributes to output row
            ``segment_ids[i]``.
        segment_ids: 1-D index tensor with one entry per row of `data`.
        num_segments: number of rows in the output.

    Returns:
        Tensor of shape ``(num_segments, data.size(1))``. Segments that
        receive no rows come out as zeros, because the divisor is clamped
        to at least 1.
    """
    n_cols = data.size(1)
    shape = (num_segments, n_cols)
    # Repeat each row's id across the columns so the index matches `data`.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    totals = data.new_zeros(shape).scatter_add_(0, index, data)
    # Scattering ones with the same index counts the rows per segment.
    counts = data.new_zeros(shape).scatter_add_(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)

# basic setup

# Benchmark fixture: one random graph replicated into a batch, with random
# node and edge features living on the GPU.
n_graph, n_node, n_edge = 128, 20, 100

fv = torch.randn(n_node * n_graph, 128).cuda()  # one 128-wide row per batched node
fe = torch.randn(n_edge * n_graph, 64).cuda()   # one 64-wide row per batched edge

g = dgl.rand_graph(n_node, n_edge)
bg = dgl.batch([g] * n_graph).to('cuda')

# Endpoint node ids of every edge in the batched graph.
u, v = bg.edges()

# test 1 copy_e_sum

%%timeit 
unsorted_segment_sum(fe, u, len(fv))
# 58.5 µs ± 3.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%%timeit
dgl.ops.copy_e_sum(bg, fe)
# 238 µs ± 3.15 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# test 2 copy_e_mean

%%timeit 
unsorted_segment_mean(fe, u, len(fv))
# 171 µs ± 4.46 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
dgl.ops.copy_e_mean(bg, fe)
# 574 µs ± 7.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# test 3  gsddmm sub
# Per-edge difference of the two endpoint features (u, v come from bg.edges()):
# plain PyTorch indexing vs. the fused dgl.ops.gsddmm call.

%%timeit
fv[u]-fv[v]
# 135 µs ± 865 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
dgl.ops.gsddmm(bg, 'sub', fv, fv)
# 215 µs ± 3.13 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# test 4 gsddmm mul
# Same comparison as test 3, with an element-wise product instead of a difference.

%%timeit
fv[u]*fv[v]
# 135 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
dgl.ops.gsddmm(bg, 'mul', fv, fv)
# 185 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1 Like

I believe your test data is too small. The runtimes you are measuring are dominated by Python–C++ interop overhead, and because you do not synchronize after the operation, you stop the timer before the GPU has actually finished the work.

You should use cuda events to measure the time.

# Time GPU work with CUDA events. The events are recorded on the GPU stream,
# so the measurement covers kernel execution and not just the asynchronous
# kernel launches issued from Python.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()

for _ in range(10000):
    # benchmarking ops here should be run 10000 times for accurate measurement.
    pass  # placeholder so the snippet is valid Python; replace with the op under test

end_event.record()

# Block until all queued GPU work (including end_event) has finished;
# elapsed_time() is only meaningful once both events have completed.
torch.cuda.synchronize()

# Total time for all iterations, in milliseconds.
print(start_event.elapsed_time(end_event))

You can also use Benchmark Utils - torch.utils.benchmark — PyTorch 2.4 documentation.

1 Like

Thanks, the results make more sense now.

  • setup
"""
for n_graph in tqdm([128,1024,4096]):
    for n_node, n_edge in [(20,100), (200, 1_000), (200, 10_000)]:
        g = dgl.rand_graph(n_node, n_edge)
        fv = torch.randn(n_node*n_graph, 128).cuda()
        fe = torch.randn(n_edge*n_graph, 64).cuda()
        bg = dgl.batch([g]*n_graph).to('cuda')
        u,v = bg.edges()

        ...

"""
  • results
# n_graph, n_node, n_edge
[---------------- copy_e_sum ----------------]
                        |   torch   |    dgl  
1 threads: -----------------------------------
      [128,20,100]      |     31.1  |     83.8
      [128,200,1000]    |     25.3  |     69.2
      [128,200,10000]   |    381.5  |    386.3
      [1024,20,100]     |     22.1  |     73.3
      [1024,200,1000]   |    458.5  |    460.2
      [1024,200,10000]  |   3071.1  |   3084.5
      [4096,20,100]     |    158.0  |    159.1
      [4096,200,1000]   |   1834.3  |   1842.7
      [4096,200,10000]  |  12249.6  |  12297.7

Times are in microseconds (us).

[--------------- copy_e_mean ----------------]
                        |   torch   |    dgl  
1 threads: -----------------------------------
      [128,20,100]      |     58.5  |    201.5
      [128,200,1000]    |    112.3  |    223.4
      [128,200,10000]   |   1155.2  |    408.4
      [1024,20,100]     |     65.2  |    217.9
      [1024,200,1000]   |   1477.3  |    581.7
      [1024,200,10000]  |   9148.0  |   3200.3
      [4096,20,100]     |    559.4  |    209.9
      [4096,200,1000]   |   5883.1  |   2334.4
      [4096,200,10000]  |  36545.6  |  12785.7

Times are in microseconds (us).

[----------------- u add v -----------------]
                        |   torch   |   dgl  
1 threads: ----------------------------------
      [128,20,100]      |     25.6  |    75.2
      [128,200,1000]    |    385.7  |    98.1
      [128,200,10000]   |   3663.4  |   795.2
      [1024,20,100]     |    283.8  |    57.4
      [1024,200,1000]   |   3118.0  |   731.0
      [1024,200,10000]  |  29312.0  |  6446.4
      [4096,20,100]     |   1250.0  |   318.3
      [4096,200,1000]   |  12475.7  |  2937.6

Times are in microseconds (us).

[----------------- u mul v -----------------]
                        |   torch   |   dgl  
1 threads: ----------------------------------
      [128,20,100]      |     26.1  |    58.3
      [128,200,1000]    |    385.7  |    98.1
      [128,200,10000]   |   3663.4  |   795.3
      [1024,20,100]     |    283.8  |    55.1
      [1024,200,1000]   |   3117.8  |   731.5
      [1024,200,10000]  |  29314.0  |  6447.4
      [4096,20,100]     |   1250.1  |   318.3
      [4096,200,1000]   |  12477.2  |  2937.8

Times are in microseconds (us).
Testing Code
import torch
import dgl

torch.__version__  # '2.3.1'
dgl.__version__  # '2.3.0+cu121'
# gpu: a single RTX 4090

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the rows of `data` that share a segment id.

    PyTorch stand-in for TensorFlow's `unsorted_segment_sum`.

    Args:
        data: 2-D tensor; row ``i`` is added to the output row
            ``segment_ids[i]``.
        segment_ids: 1-D index tensor with one entry per row of `data`.
        num_segments: number of rows in the output.

    Returns:
        Tensor of shape ``(num_segments, data.size(1))`` with the dtype and
        device of `data`; rows that receive no input stay zero.
    """
    n_cols = data.size(1)
    # scatter_add_ needs an index with the same shape as the source, so the
    # per-row id is repeated across every column.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    out = data.new_zeros((num_segments, n_cols))
    out.scatter_add_(0, index, data)
    return out


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Average the rows of `data` that share a segment id.

    Args:
        data: 2-D tensor; row ``i`` contributes to output row
            ``segment_ids[i]``.
        segment_ids: 1-D index tensor with one entry per row of `data`.
        num_segments: number of rows in the output.

    Returns:
        Tensor of shape ``(num_segments, data.size(1))``. Segments that
        receive no rows come out as zeros, because the divisor is clamped
        to at least 1.
    """
    n_cols = data.size(1)
    shape = (num_segments, n_cols)
    # Repeat each row's id across the columns so the index matches `data`.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    totals = data.new_zeros(shape).scatter_add_(0, index, data)
    # Scattering ones with the same index counts the rows per segment.
    counts = data.new_zeros(shape).scatter_add_(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)

def torch_add(fv, u, v):
    """Return, for each index pair, the element-wise sum ``fv[u] + fv[v]``."""
    lhs = fv[u]
    rhs = fv[v]
    return lhs + rhs

def torch_mul(fv, u, v):
    """Return, for each index pair, the element-wise product ``fv[u] * fv[v]``."""
    lhs = fv[u]
    rhs = fv[v]
    return lhs * rhs

from tqdm.auto import tqdm
import torch.utils.benchmark as benchmark

results = []

# One entry per operator under test: (label, torch stmt, torch setup, dgl stmt).
# dgl.ops.copy_e_sum / copy_e_mean reduce edge features onto the *destination*
# node of each edge, so the PyTorch reference scatters by `v` (destination
# ids) to compute the same quantity as the DGL call it is compared against.
bench_cases = [
    (
        'copy_e_sum',
        'unsorted_segment_sum(fe, v, len(fv))',
        'from __main__ import unsorted_segment_sum',
        'dgl.ops.copy_e_sum(bg, fe)',
    ),
    (
        'copy_e_mean',
        'unsorted_segment_mean(fe, v, len(fv))',
        'from __main__ import unsorted_segment_mean',
        'dgl.ops.copy_e_mean(bg, fe)',
    ),
    (
        'u add v',
        'torch_add(fv,u,v)',
        'from __main__ import torch_add',
        "dgl.ops.gsddmm(bg, 'add', fv, fv)",
    ),
    (
        'u mul v',
        'torch_mul(fv,u,v)',
        'from __main__ import torch_mul',
        "dgl.ops.gsddmm(bg, 'mul', fv, fv)",
    ),
]

for n_graph in tqdm([128,1024,4096]):
    for n_node, n_edge in [(20,100), (200, 1_000), (200, 10_000)]:
        # Fixture: one random graph replicated n_graph times and batched on
        # the GPU, with random node (fv) and edge (fe) features.
        g = dgl.rand_graph(n_node, n_edge)
        fv = torch.randn(n_node*n_graph, 128).cuda()
        fe = torch.randn(n_edge*n_graph, 64).cuda()
        bg = dgl.batch([g]*n_graph).to('cuda')
        u,v = bg.edges()

        sub_label = f'[{n_graph},{n_node},{n_edge}]'

        for label, torch_stmt, torch_setup, dgl_stmt in bench_cases:
            # Measure the torch variant first, then the dgl variant, so the
            # result order matches the column order of the comparison table.
            for description, stmt, setup in (
                ('torch', torch_stmt, torch_setup),
                ('dgl', dgl_stmt, 'import dgl'),
            ):
                results.append(
                    benchmark.Timer(
                        stmt=stmt,
                        setup=setup,
                        # Built inline so no long-lived dict keeps the GPU
                        # tensors alive after the `del` below.
                        globals={'bg':bg,'fe': fe,'fv':fv,'u':u,'v':v},
                        label=label,
                        sub_label=sub_label,
                        description=description,
                    ).blocked_autorange(min_run_time=1)
                )

        # Drop the fixture before the next (possibly larger) configuration
        # is allocated.
        del g,bg,u,v,fv,fe

compare = benchmark.Compare(results)
compare.print()
2 Likes

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.